spatialcheckpoint 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. spatialcheckpoint/__init__.py +39 -0
  2. spatialcheckpoint/analysis/__init__.py +1 -0
  3. spatialcheckpoint/analysis/colocalization.py +516 -0
  4. spatialcheckpoint/analysis/domain_annotation.py +472 -0
  5. spatialcheckpoint/analysis/gradient.py +211 -0
  6. spatialcheckpoint/analysis/spatial_expression.py +320 -0
  7. spatialcheckpoint/analysis/spatial_features.py +673 -0
  8. spatialcheckpoint/cli.py +270 -0
  9. spatialcheckpoint/configs/checkpoint_panel.yaml +80 -0
  10. spatialcheckpoint/configs/spatial_datasets.yaml +204 -0
  11. spatialcheckpoint/data/__init__.py +1 -0
  12. spatialcheckpoint/data/download.py +702 -0
  13. spatialcheckpoint/data/loader.py +199 -0
  14. spatialcheckpoint/data/preprocess.py +506 -0
  15. spatialcheckpoint/model/__init__.py +1 -0
  16. spatialcheckpoint/model/archetype_discovery.py +545 -0
  17. spatialcheckpoint/model/classifier.py +726 -0
  18. spatialcheckpoint/model/explainer.py +136 -0
  19. spatialcheckpoint/model/trainer.py +136 -0
  20. spatialcheckpoint/utils/__init__.py +1 -0
  21. spatialcheckpoint/utils/gene_sets.py +171 -0
  22. spatialcheckpoint/utils/metrics.py +233 -0
  23. spatialcheckpoint/validation/__init__.py +1 -0
  24. spatialcheckpoint/validation/bulk_mapping.py +90 -0
  25. spatialcheckpoint/validation/clinical_association.py +825 -0
  26. spatialcheckpoint/visualization/__init__.py +1 -0
  27. spatialcheckpoint/visualization/paper_figures.py +1489 -0
  28. spatialcheckpoint/visualization/spatial_plots.py +227 -0
  29. spatialcheckpoint-0.1.0.dist-info/METADATA +462 -0
  30. spatialcheckpoint-0.1.0.dist-info/RECORD +33 -0
  31. spatialcheckpoint-0.1.0.dist-info/WHEEL +5 -0
  32. spatialcheckpoint-0.1.0.dist-info/entry_points.txt +2 -0
  33. spatialcheckpoint-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,39 @@
1
+ """SpatialCheckpoint: Spatial heterogeneity profiling of immune checkpoints."""
2
+ __version__ = "0.1.0"
3
+
4
+ from spatialcheckpoint.data.preprocess import SpatialDataPreprocessor
5
+ from spatialcheckpoint.data.loader import SpatialDataLoader
6
+ from spatialcheckpoint.analysis.spatial_expression import SpatialCheckpointProfiler
7
+ from spatialcheckpoint.analysis.spatial_features import SpatialFeatureEngineer
8
+ from spatialcheckpoint.analysis.colocalization import CheckpointColocalizationAnalyzer
9
+ from spatialcheckpoint.model.archetype_discovery import SpatialArchetypeDiscovery
10
+ from spatialcheckpoint.model.classifier import SpatialArchetypeClassifier
11
+ from spatialcheckpoint.model.trainer import ArchetypeModelTrainer
12
+ from spatialcheckpoint.model.explainer import ArchetypeExplainer
13
+ from spatialcheckpoint.utils.gene_sets import (
14
+ get_all_checkpoint_genes,
15
+ get_category_genes,
16
+ get_immune_cell_markers,
17
+ get_ligand_receptor_pairs,
18
+ )
19
+
20
+ __all__ = [
21
+ "__version__",
22
+ # Data
23
+ "SpatialDataPreprocessor",
24
+ "SpatialDataLoader",
25
+ # Analysis
26
+ "SpatialCheckpointProfiler",
27
+ "SpatialFeatureEngineer",
28
+ "CheckpointColocalizationAnalyzer",
29
+ # Model
30
+ "SpatialArchetypeDiscovery",
31
+ "SpatialArchetypeClassifier",
32
+ "ArchetypeModelTrainer",
33
+ "ArchetypeExplainer",
34
+ # Gene sets
35
+ "get_all_checkpoint_genes",
36
+ "get_category_genes",
37
+ "get_immune_cell_markers",
38
+ "get_ligand_receptor_pairs",
39
+ ]
@@ -0,0 +1 @@
1
+ """Analysis module."""
@@ -0,0 +1,516 @@
1
+ """Spatial co-localization analysis of immune checkpoints."""
2
+ from __future__ import annotations
3
+
4
+ from typing import List, Optional
5
+ import warnings
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import scipy.sparse as sp
10
+ from scipy import stats
11
+ from scipy.spatial import cKDTree
12
+ import scanpy as sc
13
+
14
+ try:
15
+ import squidpy as sq
16
+ HAS_SQUIDPY = True
17
+ except ImportError:
18
+ HAS_SQUIDPY = False
19
+
20
+
21
class CheckpointColocalizationAnalyzer:
    """Spatial co-localization analysis for immune checkpoints in Visium data.

    The analyzer reads spot coordinates from ``adata.obsm['spatial']`` and
    expression values from ``adata.X`` (dense or scipy-sparse), looking genes
    up through ``adata.var_names``.

    Attributes:
        adata: The annotated data object passed to the constructor.
        cp_genes: The requested checkpoint genes that exist in
            ``adata.var_names``, in the order they were given.
        immune_markers: Mapping ``cell_type -> list of marker genes``
            (empty dict when none was supplied).
    """

    # Known ligand-receptor pairs for checkpoints
    LR_PAIRS = [
        ('CD274', 'PDCD1'),       # PD-L1 / PD-1
        ('PDCD1LG2', 'PDCD1'),    # PD-L2 / PD-1
        ('LGALS9', 'HAVCR2'),     # Galectin-9 / TIM-3
        ('PVR', 'TIGIT'),         # CD155 / TIGIT
        ('NECTIN2', 'TIGIT'),     # CD112 / TIGIT
        ('FGL1', 'LAG3'),         # FGL1 / LAG-3
        ('CEACAM1', 'LAG3'),      # CEACAM1 / LAG-3
        ('CD47', 'SIRPA'),        # CD47 / SIRPα
    ]

    # Human-readable names reported in the ``alias`` column for LR_PAIRS.
    _LR_ALIASES = {
        ('CD274', 'PDCD1'): 'PD-L1/PD-1',
        ('PDCD1LG2', 'PDCD1'): 'PD-L2/PD-1',
        ('LGALS9', 'HAVCR2'): 'Galectin-9/TIM-3',
        ('PVR', 'TIGIT'): 'CD155/TIGIT',
        ('NECTIN2', 'TIGIT'): 'CD112/TIGIT',
        ('FGL1', 'LAG3'): 'FGL1/LAG-3',
        ('CEACAM1', 'LAG3'): 'CEACAM1/LAG-3',
        ('CD47', 'SIRPA'): 'CD47/SIRPα',
    }

    _PROXIMITY_COLUMNS = ['gene', 'immune_cell', 'mean_dist',
                          'expected_dist', 'z_score', 'p_value']
    _LR_COLUMNS = ['ligand', 'receptor', 'alias', 'lr_score', 'pvalue',
                   'ligand_in_dataset', 'receptor_in_dataset']
    _NICHE_COLUMNS = ['niche_id', 'top_checkpoints', 'top_immune', 'n_spots',
                      'pct_tumor_core', 'pct_stroma', 'pct_immune_enriched']

    def __init__(self, adata: sc.AnnData, checkpoint_genes: List[str],
                 immune_markers: Optional[dict] = None):
        """Store the dataset and keep only checkpoint genes present in it.

        Args:
            adata: Annotated data object providing ``X``, ``var_names``,
                ``obs`` and ``obsm['spatial']``.
            checkpoint_genes: Candidate checkpoint gene names; names missing
                from ``adata.var_names`` are silently dropped.
            immune_markers: Optional mapping ``cell_type -> marker genes``.
        """
        self.adata = adata
        self.cp_genes = [g for g in checkpoint_genes if g in adata.var_names]
        self.immune_markers = immune_markers or {}
        self._coords = None  # filled lazily by _get_coords()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_coords(self) -> np.ndarray:
        """Return ``adata.obsm['spatial']`` as a float array, cached after first use."""
        if self._coords is None:
            self._coords = np.array(self.adata.obsm['spatial'], dtype=float)
        return self._coords

    def _get_expr(self, genes: list[str]) -> np.ndarray:
        """Return a dense float matrix with one column per requested gene.

        Columns follow the order of ``genes``. A scipy-sparse ``adata.X`` is
        densified for the selected columns only.
        """
        indices = [self.adata.var_names.get_loc(g) for g in genes]
        X = self.adata.X[:, indices]
        if sp.issparse(X):
            X = X.toarray()
        return np.array(X, dtype=float)

    # ------------------------------------------------------------------
    # Public methods
    # ------------------------------------------------------------------

    def pairwise_cooccurrence(self, interval: int = 50,
                              n_intervals: int = 10) -> dict:
        """Pairwise checkpoint spatial co-occurrence across distance intervals.

        For each ordered pair of checkpoint genes (g1, g2), restricted to the
        20 highest-variance checkpoint genes:
          - high-expression spots are those strictly above the gene's 75th
            percentile;
          - for d = 0, interval, ..., n_intervals * interval the observed
            fraction of g2-high spots around g1-high spots is divided by the
            global fraction of g2-high spots.

        Args:
            interval: Step between consecutive radii, in coordinate units.
            n_intervals: Number of steps; ``n_intervals + 1`` radii are used.

        Returns:
            dict mapping ``(gene1, gene2)`` to an array of co-occurrence
            ratios, one per radius. Empty when fewer than two checkpoint
            genes are available.
        """
        if len(self.cp_genes) < 2:
            return {}

        if HAS_SQUIDPY:
            return self._pairwise_cooccurrence_squidpy(interval, n_intervals)
        return self._pairwise_cooccurrence_fallback(interval, n_intervals)

    def _pairwise_cooccurrence_squidpy(self, interval: int, n_intervals: int) -> dict:
        """Entry point used when squidpy is importable.

        ``squidpy.gr.co_occurrence`` scores the categories of one categorical
        ``obs`` column; it takes no ``genes`` argument and requires a real
        ``cluster_key``. Gene-level "high expression" sets overlap, so they
        cannot be expressed as a single categorical column. The previous call
        therefore always raised, warned, and fell back. This method now runs
        the manual computation directly, which yields the same result without
        the spurious warning or the AnnData copy.
        """
        return self._pairwise_cooccurrence_fallback(interval, n_intervals)

    def _pairwise_cooccurrence_fallback(self, interval: int, n_intervals: int) -> dict:
        """Manual co-occurrence computation using a KD-tree over all spots.

        At radius 0 the observed value is the fraction of g1-high spots that
        are also g2-high. At larger radii it is the mean, over g1-high spots,
        of the fraction of neighbours (excluding the spot itself) that are
        g2-high; a spot without neighbours contributes 0.
        """
        top_genes = self._select_top_variable_genes(n=20)
        if len(top_genes) < 2:
            return {}

        coords = self._get_coords()
        expr = self._get_expr(top_genes)  # (n_spots, n_top)
        n_top = len(top_genes)
        distances = np.arange(0, (n_intervals + 1) * interval, interval, dtype=float)

        tree = cKDTree(coords)

        # Per-gene "high" masks and their global prevalence (the expectation).
        threshold_75 = np.percentile(expr, 75, axis=0)
        high_masks = expr > threshold_75[np.newaxis, :]
        global_frac = high_masks.mean(axis=0)
        high_float = high_masks.astype(float)

        result = {}
        for i, g1 in enumerate(top_genes):
            g1_high = np.flatnonzero(high_masks[:, i])
            if g1_high.size == 0:
                continue

            # The neighbourhood search depends only on (g1, radius), so it is
            # done once here and shared by every partner gene.
            observed = np.zeros((len(distances), n_top))
            for d_idx, d in enumerate(distances):
                if d < 1e-6:
                    # Distance = 0: only the spot itself
                    observed[d_idx] = high_float[g1_high].mean(axis=0)
                    continue

                neighbors_list = tree.query_ball_point(coords[g1_high], r=d)
                fracs = np.zeros((g1_high.size, n_top))
                for row, (spot_idx, nbrs) in enumerate(zip(g1_high, neighbors_list)):
                    nbrs_excl = [n for n in nbrs if n != spot_idx]
                    if nbrs_excl:
                        fracs[row] = high_float[nbrs_excl].mean(axis=0)
                observed[d_idx] = fracs.mean(axis=0)

            ratios = observed / (global_frac[np.newaxis, :] + 1e-9)
            for j, g2 in enumerate(top_genes):
                if j != i:
                    result[(g1, g2)] = ratios[:, j].copy()

        return result

    def _select_top_variable_genes(self, n: int = 20) -> list:
        """Select top-n checkpoint genes by expression variance.

        Returns all checkpoint genes (original order) when there are at most
        ``n`` of them; otherwise the ``n`` highest-variance genes, most
        variable first.
        """
        if len(self.cp_genes) <= n:
            return list(self.cp_genes)
        expr = self._get_expr(self.cp_genes)
        variances = expr.var(axis=0)
        top_idx = np.argsort(variances)[::-1][:n]
        return [self.cp_genes[i] for i in top_idx]

    # ------------------------------------------------------------------

    def checkpoint_immune_proximity(self, n_permutations: int = 200,
                                    top_pct: float = 0.25) -> pd.DataFrame:
        """Distance from checkpoint-high spots to nearest immune cell spots.

        For each checkpoint gene and each immune cell type with at least one
        marker in the dataset:
          1. checkpoint-high spots are those strictly above the
             ``1 - top_pct`` quantile of the gene;
          2. immune spots are those with any marker expressed (> 0);
          3. the observed statistic is the mean distance from checkpoint-high
             spots to their nearest immune spot;
          4. the null distribution is built from ``n_permutations`` random
             spot sets of the same size (seeded generator, seed 42).

        Args:
            n_permutations: Number of random spot sets per (gene, cell type).
            top_pct: Fraction of spots treated as checkpoint-high.

        Returns:
            DataFrame with columns [gene, immune_cell, mean_dist,
            expected_dist, z_score, p_value]. ``p_value`` is the fraction of
            permutations whose mean distance is <= the observed one.
            Negative z_score = closer than random (co-localized).
        """
        if not self.cp_genes or not self.immune_markers:
            return pd.DataFrame(columns=self._PROXIMITY_COLUMNS)

        coords = self._get_coords()
        n_spots = coords.shape[0]
        rng = np.random.default_rng(42)

        cp_expr = self._get_expr(self.cp_genes)  # (n_spots, n_cp)
        cp_thresholds = np.quantile(cp_expr, 1.0 - top_pct, axis=0)

        # Distance from every spot to its nearest immune spot, once per cell
        # type. Observed and permuted statistics are then plain lookups, so
        # no tree is rebuilt or re-queried inside the gene/permutation loops.
        nearest_immune_dist = {}
        for cell_type, markers in self.immune_markers.items():
            avail = [m for m in markers if m in self.adata.var_names]
            if not avail:
                continue
            immune_mask = (self._get_expr(avail) > 0).any(axis=1)
            if not immune_mask.any():
                continue
            immune_tree = cKDTree(coords[immune_mask])
            dists, _ = immune_tree.query(coords, k=1)
            nearest_immune_dist[cell_type] = dists

        records = []
        for g_idx, gene in enumerate(self.cp_genes):
            cp_high_spots = np.flatnonzero(cp_expr[:, g_idx] > cp_thresholds[g_idx])
            n_cp_high = cp_high_spots.size
            if n_cp_high == 0:
                continue

            for cell_type, dist_to_immune in nearest_immune_dist.items():
                observed_mean = float(np.mean(dist_to_immune[cp_high_spots]))

                # Permutation test: draw random spot sets of the same size.
                null_means = np.empty(n_permutations)
                for perm_i in range(n_permutations):
                    perm_indices = rng.choice(n_spots, size=n_cp_high, replace=False)
                    null_means[perm_i] = float(np.mean(dist_to_immune[perm_indices]))

                null_mean = float(np.mean(null_means))
                null_std = float(np.std(null_means))
                z_score = (observed_mean - null_mean) / (null_std + 1e-9)
                p_value = float(np.mean(null_means <= observed_mean))

                records.append({
                    'gene': gene,
                    'immune_cell': cell_type,
                    'mean_dist': observed_mean,
                    'expected_dist': null_mean,
                    'z_score': z_score,
                    'p_value': p_value,
                })

        return pd.DataFrame(records, columns=self._PROXIMITY_COLUMNS)

    # ------------------------------------------------------------------

    def ligand_receptor_spatial_analysis(self) -> pd.DataFrame:
        """Spatial proximity analysis for known checkpoint LR pairs.

        For each (ligand, receptor) pair in ``LR_PAIRS``:
          1. If either gene is missing, or no spot is strictly above the
             ligand's 75th percentile, a row with NaN score/p-value is emitted.
          2. Otherwise the mean receptor expression over the k nearest
             neighbours (k = min(10, n_spots - 1), self excluded) of
             ligand-high spots is compared with the same quantity for an
             equally sized random set of non-ligand-high spots.
          3. ``lr_score = log2((neighbor_rec + 1e-6) / (random_rec + 1e-6))``.
          4. A two-sided permutation p-value is computed from 500
             permutations (seeded generator, seed 42).

        Returns:
            DataFrame with columns [ligand, receptor, alias, lr_score, pvalue,
            ligand_in_dataset, receptor_in_dataset], one row per LR pair.
        """
        var_names_set = set(self.adata.var_names)
        coords = self._get_coords()
        n_spots = coords.shape[0]
        k = min(10, n_spots - 1)
        n_permutations = 500
        rng = np.random.default_rng(42)

        tree = cKDTree(coords)
        # kNN table for every spot (self column dropped). Built on first use:
        # reaching that point implies at least two spots, so the query
        # result is guaranteed to be 2-D.
        all_nbrs = None

        def nan_record(ligand, receptor, alias, lig_in, rec_in):
            return {
                'ligand': ligand,
                'receptor': receptor,
                'alias': alias,
                'lr_score': np.nan,
                'pvalue': np.nan,
                'ligand_in_dataset': lig_in,
                'receptor_in_dataset': rec_in,
            }

        records = []
        for ligand, receptor in self.LR_PAIRS:
            lig_in = ligand in var_names_set
            rec_in = receptor in var_names_set
            alias = self._LR_ALIASES.get((ligand, receptor), f'{ligand}/{receptor}')

            if not (lig_in and rec_in):
                records.append(nan_record(ligand, receptor, alias, lig_in, rec_in))
                continue

            lig_expr = self._get_expr([ligand])[:, 0]
            rec_expr = self._get_expr([receptor])[:, 0]

            # Ligand-high spots: top 25%
            lig_threshold = np.percentile(lig_expr, 75)
            lig_high = np.where(lig_expr > lig_threshold)[0]
            n_high = len(lig_high)
            if n_high == 0:
                records.append(nan_record(ligand, receptor, alias, lig_in, rec_in))
                continue

            if all_nbrs is None:
                _, nbr_idx = tree.query(coords, k=k + 1)
                # First column is the query spot itself; drop it.
                all_nbrs = nbr_idx[:, 1:]

            neighbor_rec = rec_expr[all_nbrs[lig_high]].mean()

            # Reference set: random non-ligand-high spots when enough exist,
            # otherwise random spots from the whole slide.
            non_lig_high = np.where(lig_expr <= lig_threshold)[0]
            if len(non_lig_high) >= n_high:
                random_spots = rng.choice(non_lig_high, size=n_high, replace=False)
            else:
                random_spots = rng.choice(n_spots, size=n_high, replace=False)
            random_rec = rec_expr[all_nbrs[random_spots]].mean()

            lr_score = float(np.log2((neighbor_rec + 1e-6) / (random_rec + 1e-6)))

            # Permutation test: both sets drawn at random from all spots.
            null_scores = np.empty(n_permutations)
            for perm_i in range(n_permutations):
                perm_lig_high = rng.choice(n_spots, size=n_high, replace=False)
                perm_neighbor_rec = rec_expr[all_nbrs[perm_lig_high]].mean()

                perm_rand = rng.choice(n_spots, size=n_high, replace=False)
                perm_random_rec = rec_expr[all_nbrs[perm_rand]].mean()

                null_scores[perm_i] = float(
                    np.log2((perm_neighbor_rec + 1e-6) / (perm_random_rec + 1e-6))
                )

            # Two-sided p-value: fraction of permutations with |score| >= |observed|
            pvalue = float(np.mean(np.abs(null_scores) >= np.abs(lr_score)))

            records.append({
                'ligand': ligand,
                'receptor': receptor,
                'alias': alias,
                'lr_score': lr_score,
                'pvalue': pvalue,
                'ligand_in_dataset': lig_in,
                'receptor_in_dataset': rec_in,
            })

        return pd.DataFrame(records, columns=self._LR_COLUMNS)

    # ------------------------------------------------------------------

    def spatial_niche_analysis(self, n_niches: int = 5,
                               k_neighbors: int = 15,
                               random_state: int = 42) -> pd.DataFrame:
        """Discover spatial niches based on local checkpoint + immune expression.

        Algorithm:
          1. For each spot, average the expression of all checkpoint genes and
             available immune markers over its k nearest neighbours (self
             excluded). When no neighbour can be used (a single spot or
             ``k_neighbors < 1``) the spot's own expression is the profile.
          2. Cluster the profiles with k-means (at most ``n_niches`` clusters,
             fixed ``random_state``).
          3. Store the labels, as strings, in ``adata.obs['spatial_niche']``.
          4. Summarise each niche from the spots' own expression.

        When no feature gene is available a RuntimeWarning is issued, every
        spot is labelled ``'0'`` and an empty frame with the documented
        columns is returned.

        Returns:
            DataFrame indexed by ``niche_id`` with columns:
              - top_checkpoints: top 3 checkpoint genes by mean expression
              - top_immune: top 3 immune markers by mean expression
              - n_spots: number of spots in the niche
              - pct_tumor_core, pct_stroma, pct_immune_enriched: fractions of
                ``adata.obs['region_type']`` values (0.0 when that column is
                absent)
        """
        coords = self._get_coords()
        n_spots = coords.shape[0]
        k = min(k_neighbors, n_spots - 1)

        # Collect all genes: checkpoint + all immune marker genes
        all_immune_genes = []
        for markers in self.immune_markers.values():
            all_immune_genes.extend(markers)
        all_immune_genes = list(dict.fromkeys(all_immune_genes))  # unique, order-preserving
        avail_immune = [g for g in all_immune_genes if g in self.adata.var_names]

        feature_genes = list(self.cp_genes) + avail_immune
        if not feature_genes:
            warnings.warn("No feature genes available for niche analysis.", RuntimeWarning)
            # String label, consistent with the dtype written on the normal path.
            self.adata.obs['spatial_niche'] = '0'
            return pd.DataFrame(columns=self._NICHE_COLUMNS).set_index('niche_id')

        # Imported only once clustering is actually needed, so the early
        # exit above does not require scikit-learn.
        from sklearn.cluster import KMeans

        expr = self._get_expr(feature_genes)  # (n_spots, n_features)

        if k >= 1:
            tree = cKDTree(coords)
            _, nbr_indices = tree.query(coords, k=k + 1)  # (n_spots, k+1)
            nbr_indices = nbr_indices[:, 1:]  # drop the query spot itself
            # Mean over the neighbour axis for every spot at once.
            nbr_profiles = expr[nbr_indices].mean(axis=1)
        else:
            # cKDTree.query(k=1) returns 1-D arrays, so there is no neighbour
            # column to slice; fall back to each spot's own expression.
            nbr_profiles = expr.copy()

        # K-means clustering on neighborhood profiles
        n_niches_actual = min(n_niches, n_spots)
        kmeans = KMeans(n_clusters=n_niches_actual, random_state=random_state, n_init=10)
        niche_labels = kmeans.fit_predict(nbr_profiles)

        self.adata.obs['spatial_niche'] = niche_labels.astype(str)

        # Feature columns are laid out as [checkpoint genes..., immune markers...].
        n_cp = len(self.cp_genes)
        immune_feature_names = feature_genes[n_cp:]

        has_region = 'region_type' in self.adata.obs.columns
        region_labels = self.adata.obs['region_type'].values if has_region else None

        records = []
        for niche_id in range(n_niches_actual):
            niche_mask = niche_labels == niche_id
            n_spots_niche = int(niche_mask.sum())
            if n_spots_niche == 0:
                continue

            niche_expr = expr[niche_mask]  # (n_niche_spots, n_features)

            top_checkpoints = ''
            if self.cp_genes:
                cp_means = niche_expr[:, :n_cp].mean(axis=0)
                top_cp_idx = np.argsort(cp_means)[::-1][:3]
                top_checkpoints = ', '.join(self.cp_genes[i] for i in top_cp_idx)

            top_immune = ''
            if immune_feature_names:
                immune_means = niche_expr[:, n_cp:].mean(axis=0)
                top_imm_idx = np.argsort(immune_means)[::-1][:3]
                top_immune = ', '.join(immune_feature_names[i] for i in top_imm_idx)

            # Region composition
            pct_tumor_core = 0.0
            pct_stroma = 0.0
            pct_immune_enriched = 0.0
            if has_region and region_labels is not None:
                niche_regions = region_labels[niche_mask]
                pct_tumor_core = float(np.mean(niche_regions == 'tumor_core'))
                pct_stroma = float(np.mean(niche_regions == 'stroma'))
                pct_immune_enriched = float(np.mean(niche_regions == 'immune_enriched'))

            records.append({
                'niche_id': niche_id,
                'top_checkpoints': top_checkpoints,
                'top_immune': top_immune,
                'n_spots': n_spots_niche,
                'pct_tumor_core': pct_tumor_core,
                'pct_stroma': pct_stroma,
                'pct_immune_enriched': pct_immune_enriched,
            })

        # Explicit columns keep set_index valid even if no record was produced.
        return pd.DataFrame(records, columns=self._NICHE_COLUMNS).set_index('niche_id')