spatialcheckpoint 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialcheckpoint/__init__.py +39 -0
- spatialcheckpoint/analysis/__init__.py +1 -0
- spatialcheckpoint/analysis/colocalization.py +516 -0
- spatialcheckpoint/analysis/domain_annotation.py +472 -0
- spatialcheckpoint/analysis/gradient.py +211 -0
- spatialcheckpoint/analysis/spatial_expression.py +320 -0
- spatialcheckpoint/analysis/spatial_features.py +673 -0
- spatialcheckpoint/cli.py +270 -0
- spatialcheckpoint/configs/checkpoint_panel.yaml +80 -0
- spatialcheckpoint/configs/spatial_datasets.yaml +204 -0
- spatialcheckpoint/data/__init__.py +1 -0
- spatialcheckpoint/data/download.py +702 -0
- spatialcheckpoint/data/loader.py +199 -0
- spatialcheckpoint/data/preprocess.py +506 -0
- spatialcheckpoint/model/__init__.py +1 -0
- spatialcheckpoint/model/archetype_discovery.py +545 -0
- spatialcheckpoint/model/classifier.py +726 -0
- spatialcheckpoint/model/explainer.py +136 -0
- spatialcheckpoint/model/trainer.py +136 -0
- spatialcheckpoint/utils/__init__.py +1 -0
- spatialcheckpoint/utils/gene_sets.py +171 -0
- spatialcheckpoint/utils/metrics.py +233 -0
- spatialcheckpoint/validation/__init__.py +1 -0
- spatialcheckpoint/validation/bulk_mapping.py +90 -0
- spatialcheckpoint/validation/clinical_association.py +825 -0
- spatialcheckpoint/visualization/__init__.py +1 -0
- spatialcheckpoint/visualization/paper_figures.py +1489 -0
- spatialcheckpoint/visualization/spatial_plots.py +227 -0
- spatialcheckpoint-0.1.0.dist-info/METADATA +462 -0
- spatialcheckpoint-0.1.0.dist-info/RECORD +33 -0
- spatialcheckpoint-0.1.0.dist-info/WHEEL +5 -0
- spatialcheckpoint-0.1.0.dist-info/entry_points.txt +2 -0
- spatialcheckpoint-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""SpatialCheckpoint: Spatial heterogeneity profiling of immune checkpoints."""
|
|
2
|
+
__version__ = "0.1.0"
|
|
3
|
+
|
|
4
|
+
from spatialcheckpoint.data.preprocess import SpatialDataPreprocessor
|
|
5
|
+
from spatialcheckpoint.data.loader import SpatialDataLoader
|
|
6
|
+
from spatialcheckpoint.analysis.spatial_expression import SpatialCheckpointProfiler
|
|
7
|
+
from spatialcheckpoint.analysis.spatial_features import SpatialFeatureEngineer
|
|
8
|
+
from spatialcheckpoint.analysis.colocalization import CheckpointColocalizationAnalyzer
|
|
9
|
+
from spatialcheckpoint.model.archetype_discovery import SpatialArchetypeDiscovery
|
|
10
|
+
from spatialcheckpoint.model.classifier import SpatialArchetypeClassifier
|
|
11
|
+
from spatialcheckpoint.model.trainer import ArchetypeModelTrainer
|
|
12
|
+
from spatialcheckpoint.model.explainer import ArchetypeExplainer
|
|
13
|
+
from spatialcheckpoint.utils.gene_sets import (
|
|
14
|
+
get_all_checkpoint_genes,
|
|
15
|
+
get_category_genes,
|
|
16
|
+
get_immune_cell_markers,
|
|
17
|
+
get_ligand_receptor_pairs,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"__version__",
|
|
22
|
+
# Data
|
|
23
|
+
"SpatialDataPreprocessor",
|
|
24
|
+
"SpatialDataLoader",
|
|
25
|
+
# Analysis
|
|
26
|
+
"SpatialCheckpointProfiler",
|
|
27
|
+
"SpatialFeatureEngineer",
|
|
28
|
+
"CheckpointColocalizationAnalyzer",
|
|
29
|
+
# Model
|
|
30
|
+
"SpatialArchetypeDiscovery",
|
|
31
|
+
"SpatialArchetypeClassifier",
|
|
32
|
+
"ArchetypeModelTrainer",
|
|
33
|
+
"ArchetypeExplainer",
|
|
34
|
+
# Gene sets
|
|
35
|
+
"get_all_checkpoint_genes",
|
|
36
|
+
"get_category_genes",
|
|
37
|
+
"get_immune_cell_markers",
|
|
38
|
+
"get_ligand_receptor_pairs",
|
|
39
|
+
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Analysis module."""
|
|
@@ -0,0 +1,516 @@
|
|
|
1
|
+
"""Spatial co-localization analysis of immune checkpoints."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import scipy.sparse as sp
|
|
10
|
+
from scipy import stats
|
|
11
|
+
from scipy.spatial import cKDTree
|
|
12
|
+
import scanpy as sc
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import squidpy as sq
|
|
16
|
+
HAS_SQUIDPY = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
HAS_SQUIDPY = False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CheckpointColocalizationAnalyzer:
    """Spatial co-localization analysis for immune checkpoints in Visium data.

    The analyzer reads three things from the wrapped AnnData-like object:
    ``adata.X`` (dense or scipy-sparse expression), ``adata.var_names`` (gene
    index) and ``adata.obsm['spatial']`` (spot coordinates).  Only
    ``spatial_niche_analysis`` writes back, into ``adata.obs['spatial_niche']``.
    """

    # Known ligand-receptor pairs for checkpoints
    LR_PAIRS = [
        ('CD274', 'PDCD1'),      # PD-L1 / PD-1
        ('PDCD1LG2', 'PDCD1'),   # PD-L2 / PD-1
        ('LGALS9', 'HAVCR2'),    # Galectin-9 / TIM-3
        ('PVR', 'TIGIT'),        # CD155 / TIGIT
        ('NECTIN2', 'TIGIT'),    # CD112 / TIGIT
        ('FGL1', 'LAG3'),        # FGL1 / LAG-3
        ('CEACAM1', 'LAG3'),     # CEACAM1 / LAG-3
        ('CD47', 'SIRPA'),       # CD47 / SIRPα
    ]

    # Display aliases for LR_PAIRS; pairs without an entry fall back to
    # "<ligand>/<receptor>" in ligand_receptor_spatial_analysis.
    LR_ALIASES = {
        ('CD274', 'PDCD1'): 'PD-L1/PD-1',
        ('PDCD1LG2', 'PDCD1'): 'PD-L2/PD-1',
        ('LGALS9', 'HAVCR2'): 'Galectin-9/TIM-3',
        ('PVR', 'TIGIT'): 'CD155/TIGIT',
        ('NECTIN2', 'TIGIT'): 'CD112/TIGIT',
        ('FGL1', 'LAG3'): 'FGL1/LAG-3',
        ('CEACAM1', 'LAG3'): 'CEACAM1/LAG-3',
        ('CD47', 'SIRPA'): 'CD47/SIRPα',
    }

    # Column layouts of the DataFrames returned by the public methods.
    _PROXIMITY_COLUMNS = ['gene', 'immune_cell', 'mean_dist',
                          'expected_dist', 'z_score', 'p_value']
    _LR_COLUMNS = ['ligand', 'receptor', 'alias', 'lr_score', 'pvalue',
                   'ligand_in_dataset', 'receptor_in_dataset']
    _NICHE_COLUMNS = ['niche_id', 'top_checkpoints', 'top_immune', 'n_spots',
                      'pct_tumor_core', 'pct_stroma', 'pct_immune_enriched']

    def __init__(self, adata: sc.AnnData, checkpoint_genes: List[str],
                 immune_markers: Optional[dict] = None):
        """Store the dataset and keep only checkpoint genes present in it.

        Args:
            adata: Annotated data object providing ``X``, ``var_names`` and
                ``obsm['spatial']``.
            checkpoint_genes: Candidate checkpoint gene names; names missing
                from ``adata.var_names`` are silently dropped.
            immune_markers: Optional mapping ``cell_type -> list of marker
                genes``.  ``None`` is treated as an empty mapping.
        """
        self.adata = adata
        self.cp_genes = [g for g in checkpoint_genes if g in adata.var_names]
        self.immune_markers = immune_markers or {}
        self._coords = None  # lazily filled by _get_coords()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_coords(self) -> np.ndarray:
        """Get spot spatial coordinates (n_spots, 2), cached after first use."""
        if self._coords is None:
            self._coords = np.array(self.adata.obsm['spatial'], dtype=float)
        return self._coords

    def _get_expr(self, genes: list[str]) -> np.ndarray:
        """Get dense float expression for given genes (n_spots, n_genes).

        Handles sparse ``adata.X`` by densifying only the selected columns.
        """
        indices = [self.adata.var_names.get_loc(g) for g in genes]
        X = self.adata.X[:, indices]
        if sp.issparse(X):
            X = X.toarray()
        return np.array(X, dtype=float)

    @staticmethod
    def _lr_record(ligand, receptor, alias, lr_score, pvalue, lig_in, rec_in) -> dict:
        """Build one output row for ligand_receptor_spatial_analysis."""
        return {
            'ligand': ligand,
            'receptor': receptor,
            'alias': alias,
            'lr_score': lr_score,
            'pvalue': pvalue,
            'ligand_in_dataset': lig_in,
            'receptor_in_dataset': rec_in,
        }

    # ------------------------------------------------------------------
    # Public methods
    # ------------------------------------------------------------------

    def pairwise_cooccurrence(self, interval: int = 50,
                              n_intervals: int = 10) -> dict:
        """Pairwise checkpoint spatial co-occurrence across distance intervals.

        For each ordered pair of checkpoint genes (g1, g2):
          - Define high-expression spots: expression strictly above the
            gene's 75th percentile
          - For distances d = 0, interval, 2*interval, ... n_intervals*interval:
            compute the mean fraction of g2-high spots among the neighbours
            within distance d of each g1-high spot, vs the global g2-high
            fraction
          - co_occurrence_ratio[d] = observed / expected

        If squidpy is importable, the squidpy-based path is tried first and
        the manual computation is used as its fallback.

        Only the top-20 most variable checkpoint genes are considered
        (to keep computation tractable).

        Returns:
            dict mapping (gene1, gene2) -> np.ndarray of co-occurrence ratios;
            empty when fewer than two checkpoint genes are available.
        """
        if len(self.cp_genes) < 2:
            return {}

        if HAS_SQUIDPY:
            return self._pairwise_cooccurrence_squidpy(interval, n_intervals)
        return self._pairwise_cooccurrence_fallback(interval, n_intervals)

    def _pairwise_cooccurrence_squidpy(self, interval: int, n_intervals: int) -> dict:
        """Co-occurrence via squidpy, falling back to the manual path on error."""
        top_genes = self._select_top_variable_genes(n=20)
        if len(top_genes) < 2:
            return {}

        # Build interval array: squidpy expects an array of distances
        distances = np.arange(0, (n_intervals + 1) * interval, interval, dtype=float)
        distances[0] = 1e-6  # avoid zero

        adata_sub = self.adata[:, top_genes].copy()
        # NOTE(review): squidpy's documented co_occurrence signature takes a
        # categorical ``cluster_key`` and no ``genes`` argument; if that holds
        # for the installed version this call raises and the fallback below
        # always runs -- confirm against the squidpy release in use.
        try:
            sq.gr.co_occurrence(adata_sub, cluster_key=None, genes=top_genes,
                                spatial_key='spatial', interval=distances)
            result = {}
            co_occ = adata_sub.uns.get('co_occurrence', {})
            for g1 in top_genes:
                for g2 in top_genes:
                    if g1 == g2:
                        continue
                    key = (g1, g2)
                    if key in co_occ:
                        result[key] = np.array(co_occ[key], dtype=float)
            return result
        except Exception:
            # Deliberate best-effort: any squidpy failure degrades to the
            # pure NumPy/SciPy implementation instead of aborting.
            warnings.warn(
                "squidpy co_occurrence failed; falling back to manual computation.",
                RuntimeWarning,
            )
            return self._pairwise_cooccurrence_fallback(interval, n_intervals)

    def _pairwise_cooccurrence_fallback(self, interval: int, n_intervals: int) -> dict:
        """Manual co-occurrence computation.

        Returns a dict ``(g1, g2) -> ratios`` where ``ratios`` has one entry
        per distance ``0, interval, ..., n_intervals * interval``.  Genes with
        no high-expression spot never appear as ``g1``.
        """
        top_genes = self._select_top_variable_genes(n=20)
        if len(top_genes) < 2:
            return {}

        coords = self._get_coords()
        expr = self._get_expr(top_genes)  # (n_spots, n_top)
        n_top = len(top_genes)

        distances = np.arange(0, (n_intervals + 1) * interval, interval, dtype=float)

        # Build a KDTree for all spots
        tree = cKDTree(coords)

        # Global fraction of high spots for each gene (baseline / expected)
        threshold_75 = np.percentile(expr, 75, axis=0)      # (n_top,)
        high_masks = expr > threshold_75[np.newaxis, :]     # (n_spots, n_top)
        global_frac = high_masks.mean(axis=0)               # (n_top,)

        result = {}
        for i, g1 in enumerate(top_genes):
            g1_high = np.where(high_masks[:, i])[0]
            if len(g1_high) == 0:
                continue

            # observed[d_idx, j]: mean fraction of g2-high neighbours.  The
            # neighbour sets depend only on (g1, d), so they are queried once
            # here and every partner gene j is scored from the same sets.
            observed = np.zeros((len(distances), n_top))
            for d_idx, d in enumerate(distances):
                if d < 1e-6:
                    # Distance = 0: only the spot itself
                    observed[d_idx] = high_masks[g1_high].mean(axis=0)
                    continue

                # Spots within distance d for each g1-high spot
                neighbors_list = tree.query_ball_point(coords[g1_high], r=d)
                fracs = np.zeros((len(g1_high), n_top))
                for row, (spot_idx, nbrs) in enumerate(zip(g1_high, neighbors_list)):
                    nbrs_excl = [n for n in nbrs if n != spot_idx]
                    if nbrs_excl:
                        fracs[row] = high_masks[nbrs_excl].mean(axis=0)
                    # spots without any neighbour contribute a fraction of 0
                observed[d_idx] = fracs.mean(axis=0)

            ratios_all = observed / (global_frac + 1e-9)[np.newaxis, :]
            for j, g2 in enumerate(top_genes):
                if i != j:
                    result[(g1, g2)] = ratios_all[:, j].copy()

        return result

    def _select_top_variable_genes(self, n: int = 20) -> list:
        """Select top-n checkpoint genes by expression variance.

        When at most ``n`` checkpoint genes exist they are all returned in
        their original order; otherwise genes come back sorted by decreasing
        variance.
        """
        if len(self.cp_genes) <= n:
            return list(self.cp_genes)
        expr = self._get_expr(self.cp_genes)
        variances = expr.var(axis=0)
        top_idx = np.argsort(variances)[::-1][:n]
        return [self.cp_genes[i] for i in top_idx]

    # ------------------------------------------------------------------

    def checkpoint_immune_proximity(self, n_permutations: int = 200,
                                    top_pct: float = 0.25) -> pd.DataFrame:
        """Distance from checkpoint-high spots to nearest immune cell spots.

        For each checkpoint gene:
          1. Identify checkpoint-high spots (expression strictly above the
             ``1 - top_pct`` quantile)
          2. For each immune cell type (from immune_markers):
             a. Identify immune spots: ≥1 available marker expressed (> 0)
             b. Compute mean distance: checkpoint-high → nearest immune spot
             c. Permutation test (n_permutations): draw the same number of
                spots at random (seeded generator), compute null distribution
                of mean distances
             d. z_score = (observed - null_mean) / null_std
             e. p_value = fraction of permutations with distance <= observed

        Returns DataFrame: columns=[gene, immune_cell, mean_dist, expected_dist,
                                     z_score, p_value]
        Negative z_score = closer than random (co-localized).
        Gene / cell-type combinations without checkpoint-high or immune spots
        are omitted.
        """
        if not self.cp_genes or not self.immune_markers:
            return pd.DataFrame(columns=self._PROXIMITY_COLUMNS)

        coords = self._get_coords()
        n_spots = coords.shape[0]
        rng = np.random.default_rng(42)

        # Precompute checkpoint expressions
        cp_expr = self._get_expr(self.cp_genes)                       # (n_spots, n_cp)
        cp_thresholds = np.quantile(cp_expr, 1.0 - top_pct, axis=0)   # (n_cp,)

        # For every cell type, the distance from each spot to its nearest
        # immune spot.  It is independent of the checkpoint gene and of the
        # permutation, so one KD-tree query per cell type is enough; observed
        # and permuted means below just index into this array.
        nearest_immune_dist = {}
        for cell_type, markers in self.immune_markers.items():
            avail = [m for m in markers if m in self.adata.var_names]
            if not avail:
                continue
            # A spot is "immune" if ≥1 marker expressed (> 0)
            immune_mask = (self._get_expr(avail) > 0).any(axis=1)
            immune_spots = np.where(immune_mask)[0]
            if len(immune_spots) == 0:
                continue
            dists_all, _ = cKDTree(coords[immune_spots]).query(coords, k=1)
            nearest_immune_dist[cell_type] = dists_all

        records = []
        for g_idx, gene in enumerate(self.cp_genes):
            cp_high_spots = np.where(cp_expr[:, g_idx] > cp_thresholds[g_idx])[0]
            n_cp_high = len(cp_high_spots)
            if n_cp_high == 0:
                continue

            for cell_type, dists_all in nearest_immune_dist.items():
                # Observed mean distance: checkpoint-high → nearest immune
                observed_mean = float(np.mean(dists_all[cp_high_spots]))

                # Permutation test: shuffle which spots are "checkpoint-high"
                null_means = np.empty(n_permutations)
                for perm_i in range(n_permutations):
                    perm_indices = rng.choice(n_spots, size=n_cp_high, replace=False)
                    null_means[perm_i] = float(np.mean(dists_all[perm_indices]))

                null_mean = float(np.mean(null_means))
                null_std = float(np.std(null_means))
                z_score = (observed_mean - null_mean) / (null_std + 1e-9)
                # p_value = fraction of permutations with smaller (more co-localized) distance
                p_value = float(np.mean(null_means <= observed_mean))

                records.append({
                    'gene': gene,
                    'immune_cell': cell_type,
                    'mean_dist': observed_mean,
                    'expected_dist': null_mean,
                    'z_score': z_score,
                    'p_value': p_value,
                })

        return pd.DataFrame(records, columns=self._PROXIMITY_COLUMNS)

    # ------------------------------------------------------------------

    def ligand_receptor_spatial_analysis(self) -> pd.DataFrame:
        """Spatial proximity analysis for known checkpoint LR pairs.

        For each (ligand, receptor) pair in ``LR_PAIRS``:
          1. If either gene is missing, or no spot is ligand-high, emit a row
             with NaN ``lr_score`` / ``pvalue``
          2. Otherwise take ligand-high spots (expression strictly above the
             75th percentile) and compute the mean receptor expression over
             their k=10 nearest neighbours, and over the neighbours of an
             equally sized random sample of non-ligand-high spots
          3. LR_score = log2((neighbor_receptor + 1e-6) / (random_receptor + 1e-6))
          4. Permutation test (n=500, seeded generator); the p-value is
             two-sided: fraction of permutations with |score| >= |observed|

        Returns DataFrame: columns=[ligand, receptor, alias, lr_score, pvalue,
                                     ligand_in_dataset, receptor_in_dataset]
        """
        var_names_set = set(self.adata.var_names)
        coords = self._get_coords()
        n_spots = coords.shape[0]
        k = min(10, n_spots - 1)
        n_permutations = 500
        rng = np.random.default_rng(42)

        # Build KDTree once for all spots
        tree = cKDTree(coords)
        # (n_spots, k) neighbour table, built on first use.  Neighbour sets do
        # not depend on the pair or on the permutation, so every later lookup
        # is plain indexing instead of a fresh KD-tree query.
        nbr_table = None

        records = []
        for ligand, receptor in self.LR_PAIRS:
            lig_in = ligand in var_names_set
            rec_in = receptor in var_names_set
            alias = self.LR_ALIASES.get((ligand, receptor), f'{ligand}/{receptor}')

            if not (lig_in and rec_in):
                records.append(self._lr_record(ligand, receptor, alias,
                                               np.nan, np.nan, lig_in, rec_in))
                continue

            lig_expr = self._get_expr([ligand])[:, 0]    # (n_spots,)
            rec_expr = self._get_expr([receptor])[:, 0]  # (n_spots,)

            # Ligand-high spots: top 25%
            lig_threshold = np.percentile(lig_expr, 75)
            lig_high = np.where(lig_expr > lig_threshold)[0]
            n_high = len(lig_high)
            if n_high == 0:
                records.append(self._lr_record(ligand, receptor, alias,
                                               np.nan, np.nan, lig_in, rec_in))
                continue

            if nbr_table is None:
                _, nbr_table = tree.query(coords, k=k + 1)
                # Drop the first column, which is taken to be the query spot
                # itself (holds when spot coordinates are distinct).
                nbr_table = nbr_table[:, 1:]

            # Receptor expression of every spot's neighbours: (n_spots, k)
            nbr_rec = rec_expr[nbr_table]

            # Mean receptor expression in neighbors of ligand-high spots
            neighbor_rec = nbr_rec[lig_high].mean()

            # Mean receptor expression in neighbors of random (non-ligand-high) spots
            non_lig_high = np.where(lig_expr <= lig_threshold)[0]
            if len(non_lig_high) >= n_high:
                random_spots = rng.choice(non_lig_high, size=n_high, replace=False)
            else:
                random_spots = rng.choice(n_spots, size=n_high, replace=False)
            random_rec = nbr_rec[random_spots].mean()

            # LR score: log2 fold change
            lr_score = float(np.log2((neighbor_rec + 1e-6) / (random_rec + 1e-6)))

            # Permutation test: shuffle ligand-high labels.  The two draws per
            # iteration are kept in this order so results stay reproducible.
            null_scores = np.empty(n_permutations)
            for perm_i in range(n_permutations):
                perm_lig_high = rng.choice(n_spots, size=n_high, replace=False)
                perm_neighbor_rec = nbr_rec[perm_lig_high].mean()

                perm_rand = rng.choice(n_spots, size=n_high, replace=False)
                perm_random_rec = nbr_rec[perm_rand].mean()

                null_scores[perm_i] = float(
                    np.log2((perm_neighbor_rec + 1e-6) / (perm_random_rec + 1e-6))
                )

            # Two-sided p-value: fraction of permutations with |score| >= |observed|
            pvalue = float(np.mean(np.abs(null_scores) >= np.abs(lr_score)))

            records.append(self._lr_record(ligand, receptor, alias,
                                           lr_score, pvalue, lig_in, rec_in))

        return pd.DataFrame(records, columns=self._LR_COLUMNS)

    # ------------------------------------------------------------------

    def spatial_niche_analysis(self, n_niches: int = 5,
                               k_neighbors: int = 15,
                               random_state: int = 42) -> pd.DataFrame:
        """Discover spatial niches based on local checkpoint + immune expression.

        Algorithm:
          1. For each spot, compute neighborhood profile:
             - mean expression over the k nearest neighbours (k is
               ``k_neighbors`` capped at n_spots - 1) for all cp_genes +
               available immune markers; if no neighbour can be requested
               (k < 1), the spot's own expression is used
          2. Cluster neighborhood profiles using k-means (n_niches clusters,
             capped at n_spots) with fixed random_state for reproducibility
          3. Store niche labels (as strings) in adata.obs['spatial_niche']
          4. Characterize each niche: top checkpoint and immune marker genes

        Returns DataFrame: index=niche_id, columns:
          - top_checkpoints: top 3 checkpoint genes (by mean expression in niche)
          - top_immune: top 3 immune markers
          - n_spots: number of spots in niche
          - pct_tumor_core, pct_stroma, pct_immune_enriched: fraction of the
            niche's spots whose ``adata.obs['region_type']`` equals that
            label (0.0 when the column is absent)

        When no feature gene is available a RuntimeWarning is issued and an
        empty frame with the same columns is returned.
        """
        from sklearn.cluster import KMeans

        coords = self._get_coords()
        n_spots = coords.shape[0]
        k = min(k_neighbors, n_spots - 1)

        # Collect all genes: checkpoint + all immune marker genes
        all_immune_genes = []
        for markers in self.immune_markers.values():
            all_immune_genes.extend(markers)
        all_immune_genes = list(dict.fromkeys(all_immune_genes))  # unique, order-preserving
        avail_immune = [g for g in all_immune_genes if g in self.adata.var_names]

        feature_genes = list(self.cp_genes) + avail_immune
        if not feature_genes:
            warnings.warn("No feature genes available for niche analysis.", RuntimeWarning)
            # NOTE: integer placeholder, unlike the string labels written on
            # the normal path below.
            self.adata.obs['spatial_niche'] = 0
            return pd.DataFrame(columns=self._NICHE_COLUMNS).set_index('niche_id')

        expr = self._get_expr(feature_genes)  # (n_spots, n_features)

        if k < 1:
            # No neighbour can be requested (single spot or k_neighbors < 1):
            # fall back to each spot's own profile instead of failing on the
            # 1-D result cKDTree returns for k=1.
            nbr_profiles = expr.copy()
        else:
            # Query k+1 neighbors (includes the spot itself) and drop the
            # first column, taken to be the query point.
            tree = cKDTree(coords)
            _, nbr_indices = tree.query(coords, k=k + 1)  # (n_spots, k+1)
            # Neighborhood mean profile: mean expression over the k neighbors
            nbr_profiles = expr[nbr_indices[:, 1:]].mean(axis=1)

        # K-means clustering on neighborhood profiles
        n_niches_actual = min(n_niches, n_spots)
        kmeans = KMeans(n_clusters=n_niches_actual, random_state=random_state, n_init=10)
        niche_labels = kmeans.fit_predict(nbr_profiles)

        # Store in adata.obs
        self.adata.obs['spatial_niche'] = niche_labels.astype(str)

        # Characterize each niche.  feature_genes is laid out as
        # [checkpoint genes..., immune genes...], hence the index split.
        n_cp = len(self.cp_genes)
        cp_feature_indices = list(range(n_cp))
        immune_feature_indices = list(range(n_cp, len(feature_genes)))
        immune_feature_names = feature_genes[n_cp:]

        has_region = 'region_type' in self.adata.obs.columns
        region_labels = self.adata.obs['region_type'].values if has_region else None

        records = []
        for niche_id in range(n_niches_actual):
            niche_mask = niche_labels == niche_id
            n_spots_niche = int(niche_mask.sum())
            if n_spots_niche == 0:
                continue

            niche_expr = expr[niche_mask]  # (n_niche_spots, n_features)

            # Top 3 checkpoint genes by mean expression in niche
            if self.cp_genes:
                cp_means = niche_expr[:, cp_feature_indices].mean(axis=0)
                top_cp_idx = np.argsort(cp_means)[::-1][:3]
                top_checkpoints = ', '.join([self.cp_genes[i] for i in top_cp_idx])
            else:
                top_checkpoints = ''

            # Top 3 immune marker genes by mean expression in niche
            if immune_feature_indices and immune_feature_names:
                immune_means = niche_expr[:, immune_feature_indices].mean(axis=0)
                top_imm_idx = np.argsort(immune_means)[::-1][:3]
                top_immune = ', '.join([immune_feature_names[i] for i in top_imm_idx])
            else:
                top_immune = ''

            # Region composition
            pct_tumor_core = 0.0
            pct_stroma = 0.0
            pct_immune_enriched = 0.0
            if has_region and region_labels is not None:
                niche_regions = region_labels[niche_mask]
                pct_tumor_core = float(np.mean(niche_regions == 'tumor_core'))
                pct_stroma = float(np.mean(niche_regions == 'stroma'))
                pct_immune_enriched = float(np.mean(niche_regions == 'immune_enriched'))

            records.append({
                'niche_id': niche_id,
                'top_checkpoints': top_checkpoints,
                'top_immune': top_immune,
                'n_spots': n_spots_niche,
                'pct_tumor_core': pct_tumor_core,
                'pct_stroma': pct_stroma,
                'pct_immune_enriched': pct_immune_enriched,
            })

        df = pd.DataFrame(records, columns=self._NICHE_COLUMNS).set_index('niche_id')
        return df
|