scCS-py 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scCS/__init__.py +154 -0
- scCS/bifurcation.py +226 -0
- scCS/drivers.py +237 -0
- scCS/embedding.py +621 -0
- scCS/enrichment.py +289 -0
- scCS/plot.py +1153 -0
- scCS/scores.py +752 -0
- scCS/trajectory.py +761 -0
- sccs_py-0.3.2.dist-info/METADATA +31 -0
- sccs_py-0.3.2.dist-info/RECORD +12 -0
- sccs_py-0.3.2.dist-info/WHEEL +5 -0
- sccs_py-0.3.2.dist-info/top_level.txt +1 -0
scCS/embedding.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
1
|
+
"""
|
|
2
|
+
embedding.py — Radial star embedding for scCS.
|
|
3
|
+
|
|
4
|
+
Constructs a custom 2D layout where:
|
|
5
|
+
- The bifurcation cluster (progenitor) sits at the origin (0, 0).
|
|
6
|
+
- Each terminal fate population occupies its own radial arm, evenly
|
|
7
|
+
spaced at 360/k degrees around the origin.
|
|
8
|
+
- Within each arm, cells are ordered along the radial axis by a
|
|
9
|
+
differentiation metric (pseudotime, CytoTRACE2, pathway score, etc.)
|
|
10
|
+
so that less-differentiated cells are close to the center and
|
|
11
|
+
more-differentiated cells are at the periphery.
|
|
12
|
+
- ONLY cells belonging to the bifurcation cluster or a terminal fate
|
|
13
|
+
are included. All other populations are excluded from the embedding.
|
|
14
|
+
|
|
15
|
+
The result is stored in adata_sub.obsm['X_sccs'] on the returned subset
|
|
16
|
+
AnnData, and looks like a star or sunburst when plotted — one arm per
|
|
17
|
+
fate, radiating from the progenitor.
|
|
18
|
+
|
|
19
|
+
Velocity projection
|
|
20
|
+
-------------------
|
|
21
|
+
RNA velocity vectors (from scVelo) are projected into this custom 2D
|
|
22
|
+
space by computing the transition-probability-weighted displacement of
|
|
23
|
+
each cell in the scCS coordinate system.
|
|
24
|
+
|
|
25
|
+
Differentiation metrics supported
|
|
26
|
+
----------------------------------
|
|
27
|
+
- 'pseudotime' : scVelo velocity_pseudotime (default)
|
|
28
|
+
- 'cytotrace' : CytoTRACE2 score (column in adata.obs)
|
|
29
|
+
- 'custom' : any per-cell numeric column in adata.obs
|
|
30
|
+
- np.ndarray : directly supplied per-cell scores (shape n_cells,)
|
|
31
|
+
|
|
32
|
+
In all cases, higher score = more differentiated = farther from center.
|
|
33
|
+
If the metric is inverted (e.g., CytoTRACE2 where high = less
|
|
34
|
+
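differentiated), pass invert_metric=True. The 'cytotrace' shortcut is the
exception: it is already inverted automatically, so leave invert_metric=False
there (passing True restores the raw, un-inverted orientation).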
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from __future__ import annotations
|
|
38
|
+
|
|
39
|
+
import warnings
|
|
40
|
+
from typing import List, Optional, Tuple, Union
|
|
41
|
+
|
|
42
|
+
import numpy as np
|
|
43
|
+
|
|
44
|
+
try:
|
|
45
|
+
import scvelo as scv
|
|
46
|
+
_SCVELO_AVAILABLE = True
|
|
47
|
+
except ImportError:
|
|
48
|
+
_SCVELO_AVAILABLE = False
|
|
49
|
+
|
|
50
|
+
try:
|
|
51
|
+
import scanpy as sc
|
|
52
|
+
_SCANPY_AVAILABLE = True
|
|
53
|
+
except ImportError:
|
|
54
|
+
_SCANPY_AVAILABLE = False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ---------------------------------------------------------------------------
|
|
58
|
+
# Public API
|
|
59
|
+
# ---------------------------------------------------------------------------
|
|
60
|
+
|
|
61
|
+
def build_star_embedding(
    adata,
    bifurcation_cluster: str,
    terminal_cell_types: List[str],
    cluster_key: str = "leiden",
    differentiation_metric: Union[str, np.ndarray] = "pseudotime",
    invert_metric: bool = False,
    arm_scale: float = 10.0,
    jitter: float = 0.3,
    seed: int = 42,
) -> "anndata.AnnData":
    """Build the radial star embedding on a subset of adata.

    Only cells belonging to the bifurcation cluster or a terminal fate
    cluster are included. All other populations are excluded entirely.

    Parameters
    ----------
    adata : AnnData
        Full dataset. The embedding is written to a copy of the subset, not
        to ``adata``. One caveat: with ``differentiation_metric='pseudotime'``
        and no ``velocity_pseudotime`` column present, the metric resolver may
        ask scVelo to compute it on ``adata`` in place.
    bifurcation_cluster : str
        Label of the progenitor/bifurcation cluster in adata.obs[cluster_key].
        These cells are placed at the origin.
    terminal_cell_types : list of str
        Labels of the k terminal fate populations. Each gets one radial arm,
        at angle ``j * 360 / k`` degrees for the j-th entry.
    cluster_key : str
        Column in adata.obs with cluster labels (compared as strings).
    differentiation_metric : str or np.ndarray
        How to order cells along each arm:
        - 'pseudotime' : uses adata.obs['velocity_pseudotime'] (computed if absent)
        - 'cytotrace'  : uses a CytoTRACE2 score column (must be pre-computed)
        - any str      : uses adata.obs[differentiation_metric] directly
        - np.ndarray   : per-cell scores, shape (n_cells,) for the FULL adata
        Higher value = more differentiated = farther from center.
    invert_metric : bool
        If True, invert the metric so that high values map to the center
        (use for metrics where high = less differentiated). This applies to
        string metrics and to directly supplied arrays alike. The 'cytotrace'
        shortcut is already inverted by the resolver, so passing True there
        restores the raw orientation.
    arm_scale : float
        Maximum radial distance (length of each arm).
    jitter : float
        Standard deviation of the Gaussian noise added perpendicular to each
        arm to avoid overplotting (progenitor cells use half of it, in 2D).
    seed : int
        Random seed for jitter.

    Returns
    -------
    adata_sub : AnnData
        Subset containing ONLY bifurcation + terminal fate cells.
        Star embedding stored in adata_sub.obsm['X_sccs'].
        Metadata stored in adata_sub.uns['sccs']; per-cell arm index and arm
        name stored in adata_sub.obs['sccs_arm'] / ['sccs_arm_name'].

    Raises
    ------
    ValueError
        If no cell matches the requested labels, or if a supplied metric
        array does not have one entry per cell of the full ``adata``.
    """
    import anndata  # noqa: F401  (fail fast when anndata is not installed)

    rng = np.random.default_rng(seed)
    bif_label = str(bifurcation_cluster)
    fate_labels = [str(fate) for fate in terminal_cell_types]
    obs_labels_full = adata.obs[cluster_key].astype(str).values

    # --- 0. Subset to relevant cells only ---
    keep_labels = {bif_label, *fate_labels}
    keep_mask = np.fromiter(
        (label in keep_labels for label in obs_labels_full),
        dtype=bool,
        count=len(obs_labels_full),
    )
    n_kept = int(keep_mask.sum())

    if n_kept == 0:
        raise ValueError(
            f"No cells found matching bifurcation_cluster='{bifurcation_cluster}' "
            f"or terminal_cell_types={terminal_cell_types} in "
            f"adata.obs['{cluster_key}']."
        )

    # --- Resolve the differentiation metric on the FULL adata BEFORE subsetting ---
    # Pseudotime needs the intact neighbor/velocity graph, which does not
    # survive subsetting, so string metrics are resolved first and sliced after.
    if isinstance(differentiation_metric, np.ndarray):
        arr = np.asarray(differentiation_metric, dtype=float).ravel()
        if len(arr) != adata.n_obs:
            raise ValueError(
                f"Custom metric array has length {len(arr)}, "
                f"expected {adata.n_obs} (full adata)."
            )
        metric_for_sub = _fill_nan(arr[keep_mask])
        if invert_metric:
            # Same convention as the string-metric path (max - score). NaNs are
            # filled first so the maximum is well defined.
            metric_for_sub = metric_for_sub.max() - metric_for_sub
    else:
        scores_full = _resolve_metric(adata, differentiation_metric, invert_metric)
        metric_for_sub = scores_full[keep_mask]

    adata_sub = adata[keep_mask].copy()
    obs_labels = adata_sub.obs[cluster_key].astype(str).values
    n_cells = adata_sub.n_obs

    print(f"[scCS] Subsetting: {n_kept} / {adata.n_obs} cells kept")
    print(f" ({adata.n_obs - n_kept} cells from other populations excluded)")
    for lbl in sorted(keep_labels):
        n = (obs_labels == lbl).sum()
        role = "progenitor" if lbl == bif_label else "fate"
        print(f" {lbl}: {n} cells ({role})")

    # --- 1. Final per-cell scores for the subset (inversion already applied) ---
    scores = _fill_nan(np.asarray(metric_for_sub, dtype=float).ravel())

    # --- 2. Arm directions: k evenly spaced unit vectors ---
    k = len(fate_labels)
    arm_angles_deg = np.linspace(0.0, 360.0, k, endpoint=False)
    arm_angles_rad = np.radians(arm_angles_deg)
    arm_dirs = np.stack([np.cos(arm_angles_rad), np.sin(arm_angles_rad)], axis=1)  # (k, 2)

    # --- 3. Arm membership: -1 marks progenitor cells (origin) ---
    arm_assignment = np.full(n_cells, -1, dtype=int)
    for j, fate in enumerate(fate_labels):
        arm_assignment[obs_labels == fate] = j

    # --- 4. Place cells in 2D ---
    coords = np.zeros((n_cells, 2), dtype=float)

    # Progenitors: small isotropic cloud around the origin.
    bif_mask_sub = obs_labels == bif_label
    n_bif = int(bif_mask_sub.sum())
    if n_bif > 0:
        coords[bif_mask_sub] = rng.normal(0.0, jitter * 0.5, size=(n_bif, 2))

    # Fate cells: radius from the score, min-max scaled per arm. The scaling
    # range is taken over the arm's cells together with the progenitors.
    for j in range(k):
        cell_mask = arm_assignment == j
        n_arm = int(cell_mask.sum())
        if n_arm == 0:
            continue

        arm_pool = scores[cell_mask | bif_mask_sub]
        s_min, s_max = arm_pool.min(), arm_pool.max()
        if s_max <= s_min:
            # Degenerate range: spread the cells evenly along the arm.
            r = np.linspace(0.0, arm_scale, n_arm)
        else:
            r = (scores[cell_mask] - s_min) / (s_max - s_min) * arm_scale
            r = np.clip(r, 0.0, arm_scale)

        arm_dir = arm_dirs[j]
        perp_dir = np.array([-arm_dir[1], arm_dir[0]])
        perp_noise = rng.normal(0.0, jitter, size=n_arm)
        coords[cell_mask] = np.outer(r, arm_dir) + np.outer(perp_noise, perp_dir)

    # --- 5. Store in subset adata ---
    adata_sub.obsm["X_sccs"] = coords

    if "sccs" not in adata_sub.uns:
        adata_sub.uns["sccs"] = {}
    meta = adata_sub.uns["sccs"]
    meta["arm_angles_deg"] = arm_angles_deg
    meta["arm_dirs"] = arm_dirs
    meta["arm_scale"] = arm_scale
    meta["fate_names"] = list(fate_labels)
    meta["bifurcation_cluster"] = bif_label
    meta["cluster_key"] = cluster_key
    # Row positions of the kept cells in the original adata (for velocity projection)
    meta["parent_indices"] = np.where(keep_mask)[0]

    adata_sub.obs["sccs_arm"] = arm_assignment
    adata_sub.obs["sccs_arm_name"] = [
        fate_labels[a] if a >= 0 else bif_label for a in arm_assignment
    ]

    print(
        f'\n[scCS] Star embedding built → adata_sub.obsm["X_sccs"] shape: {coords.shape}'
    )
    print(
        f' Arm angles: '
        + str({name: round(float(a), 1) for name, a in zip(fate_labels, arm_angles_deg)})
    )

    return adata_sub
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _parent_row_indices(adata_sub, adata_full) -> Optional[np.ndarray]:
    """Return each subset cell's row position in ``adata_full``, or None if unmappable."""
    n_sub = adata_sub.n_obs
    stored = adata_sub.uns.get("sccs", {}).get("parent_indices", None)
    if stored is not None:
        idx = np.asarray(stored, dtype=int).ravel()
        in_range = n_sub == 0 or (idx.min() >= 0 and idx.max() < adata_full.n_obs)
        # Only trust stored indices that still line up with the subset.
        if len(idx) == n_sub and in_range:
            return idx

    # Fallback: match by obs_names, preserving the subset's own row order so
    # that coordinates and indices stay paired.
    position = {name: i for i, name in enumerate(adata_full.obs_names)}
    try:
        return np.array([position[name] for name in adata_sub.obs_names], dtype=int)
    except KeyError:
        return None


def project_velocity_star(
    adata_sub,
    adata_full=None,
    verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Project RNA velocity into the scCS star embedding space.

    Three strategies are tried in order:

    1. scVelo ``velocity_embedding`` on the full adata (needs scVelo and
       ``adata_full.uns['velocity_graph']``), sliced to the subset rows.
    2. Manual projection with the subset-by-subset block of the full
       ``velocity_graph``: rows are normalised to sum to one and the velocity
       is the expected neighbor position minus the cell's own position. Cells
       whose row carries no weight get zero velocity.
    3. Whatever graph is available on ``adata_sub`` itself.

    The full graph is preferred because ``.uns`` matrices are not subset with
    the cells and keep full-dataset dimensions.

    Parameters
    ----------
    adata_sub : AnnData
        Subset returned by build_star_embedding(). Must have X_sccs in obsm.
        ``uns['sccs']['parent_indices']`` (set automatically) is used to map
        cells onto ``adata_full``; if it is missing or no longer matches,
        cells are matched by obs_names instead.
    adata_full : AnnData, optional
        The original full dataset with intact velocity_graph in uns.
        During strategy 1 the keys 'X_sccs_tmp' and 'velocity_sccs_tmp' are
        written to ``adata_full.obsm`` and removed again, also on failure.
        If None, falls back to using adata_sub directly (only works if
        velocity_graph was computed on the subset).
    verbose : bool
        Print progress messages.

    Returns
    -------
    vx, vy : np.ndarray, shape (n_sub_cells,)
        Velocity components in the scCS embedding.
        Also stored in adata_sub.obsm['velocity_sccs'].

    Raises
    ------
    ValueError
        If ``adata_sub.obsm['X_sccs']`` is missing.
    """
    if "X_sccs" not in adata_sub.obsm:
        raise ValueError(
            "X_sccs embedding not found. Run build_star_embedding() first."
        )

    coords_sub = np.array(adata_sub.obsm["X_sccs"])  # (n_sub, 2)

    has_full_graph = adata_full is not None and "velocity_graph" in adata_full.uns
    parent_idx = _parent_row_indices(adata_sub, adata_full) if has_full_graph else None
    if has_full_graph and parent_idx is None:
        warnings.warn(
            "Could not map the subset cells onto adata_full (parent indices "
            "are missing or stale and obs_names do not match). "
            "Falling back to subset-only projection.",
            RuntimeWarning,
            stacklevel=2,
        )
        has_full_graph = False

    # ── Strategy 1: scVelo velocity_embedding on the full adata ──────────────
    if has_full_graph and _SCVELO_AVAILABLE:
        if verbose:
            print("[scCS] Projecting velocity via scVelo on full adata → slicing to subset...")
        V_sub = None
        try:
            # Subset cells get their star coordinates; all other cells sit at
            # (0, 0).
            # NOTE(review): those placeholder cells may still be neighbors of
            # subset cells in the velocity graph and could pull projected
            # vectors toward the origin — confirm against scVelo's behavior.
            coords_full = np.zeros((adata_full.n_obs, 2), dtype=float)
            coords_full[parent_idx] = coords_sub

            adata_full.obsm["X_sccs_tmp"] = coords_full
            scv.tl.velocity_embedding(adata_full, basis="sccs_tmp")
            projected = np.array(adata_full.obsm["velocity_sccs_tmp"])[parent_idx]
            adata_sub.obsm["velocity_sccs"] = projected
            V_sub = projected
        except Exception as e:  # deliberate best effort: fall through to strategy 2
            warnings.warn(
                f"scVelo velocity_embedding on full adata failed ({e}). "
                "Falling back to graph-based projection.",
                RuntimeWarning,
                stacklevel=2,
            )
        finally:
            # Never leave the temporary keys behind on the caller's object.
            for tmp_key in ("X_sccs_tmp", "velocity_sccs_tmp"):
                if tmp_key in adata_full.obsm:
                    del adata_full.obsm[tmp_key]

        if V_sub is not None:
            if verbose:
                print(f"[scCS] Velocity projected. Shape: {V_sub.shape}")
            return V_sub[:, 0], V_sub[:, 1]

    # ── Strategy 2: graph-based projection using full adata's velocity_graph ──
    if has_full_graph:
        if verbose:
            print("[scCS] Using graph-based projection from full velocity_graph...")
        V_sub = None
        try:
            import scipy.sparse as sp

            T_full = adata_full.uns["velocity_graph"]
            T_full = T_full.tocsr() if sp.issparse(T_full) else sp.csr_matrix(T_full)

            # Restrict transitions to subset → subset.
            T_sub = T_full[parent_idx, :][:, parent_idx]  # (n_sub, n_sub)

            # Row-normalize; rows without any weight stay all-zero.
            row_sums = np.asarray(T_sub.sum(axis=1), dtype=float).ravel()
            has_mass = row_sums != 0
            inv_sums = np.zeros_like(row_sums)
            inv_sums[has_mass] = 1.0 / row_sums[has_mass]
            T_norm = sp.diags(inv_sums) @ T_sub

            expected = np.asarray(T_norm @ coords_sub)  # (n_sub, 2)
            # Subtract the cell's own position only where a transition exists,
            # so cells with no in-subset neighbors get zero velocity instead
            # of a vector pointing at the origin.
            projected = expected - coords_sub * has_mass[:, None]
            adata_sub.obsm["velocity_sccs"] = projected
            V_sub = projected
        except Exception as e:  # deliberate best effort: fall through to strategy 3
            warnings.warn(
                f"Graph-based projection from full adata failed ({e}). "
                "Falling back to subset-only projection.",
                RuntimeWarning,
                stacklevel=2,
            )

        if V_sub is not None:
            if verbose:
                print(f"[scCS] Graph-based velocity projected. Shape: {V_sub.shape}")
            return V_sub[:, 0], V_sub[:, 1]

    # ── Strategy 3: last resort — use whatever graph is in adata_sub ─────────
    if verbose:
        warnings.warn(
            "No full adata provided and no compatible velocity_graph found. "
            "Using subset-only graph (may have dimension issues). "
            "Pass adata_full=adata to project_velocity() for best results.",
            RuntimeWarning,
            stacklevel=2,
        )
    vx, vy = _graph_velocity_projection(adata_sub, coords_sub, verbose=verbose)
    adata_sub.obsm["velocity_sccs"] = np.stack([vx, vy], axis=1)
    return vx, vy
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def run_velocity_pipeline(
    adata,
    mode: str = "dynamical",
    n_top_genes: int = 2000,
    n_pcs: int = 30,
    n_neighbors: int = 30,
    min_shared_counts: int = 20,
    verbose: bool = True,
) -> None:
    """Run the full scVelo RNA velocity pipeline.

    Steps: filter/normalize, PCA and neighbors (via scanpy, only when absent
    and scanpy is installed), moments, velocity, velocity graph and velocity
    pseudotime. All results are written to ``adata`` in place; nothing is
    returned.

    Parameters
    ----------
    adata : AnnData
        Must contain layers 'spliced' and 'unspliced'. Modified in place.
    mode : {'dynamical', 'stochastic', 'steady_state'}
        Velocity model. If the dynamical model raises, the stochastic model
        is used instead and a RuntimeWarning is emitted.
    n_top_genes : int
        Number of top genes kept by ``scv.pp.filter_and_normalize``.
    n_pcs : int
        Number of principal components for PCA, neighbors and moments.
    n_neighbors : int
        Number of neighbors for the neighbor graph and moments.
    min_shared_counts : int
        Passed to ``scv.pp.filter_and_normalize``.
    verbose : bool
        Print progress messages.

    Raises
    ------
    ImportError
        If scvelo is not installed.
    ValueError
        If the 'spliced' or 'unspliced' layer is missing.
    """
    if not _SCVELO_AVAILABLE:
        raise ImportError("scvelo is required. pip install scvelo")

    missing = [layer for layer in ("spliced", "unspliced") if layer not in adata.layers]
    if missing:
        raise ValueError(
            f"Missing required layers: {missing}. "
            "These are generated by velocyto, STARsolo, or alevin-fry."
        )

    if verbose:
        print(f"[scCS] Running scVelo pipeline (mode='{mode}')...")

    scv.pp.filter_and_normalize(
        adata,
        min_shared_counts=min_shared_counts,
        n_top_genes=n_top_genes,
        log=True,
    )

    # Reuse an existing PCA / neighbor graph rather than recomputing it.
    if _SCANPY_AVAILABLE:
        if "X_pca" not in adata.obsm:
            sc.tl.pca(adata, n_comps=n_pcs)
        if "neighbors" not in adata.uns:
            sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)

    scv.pp.moments(adata, n_pcs=n_pcs, n_neighbors=n_neighbors)

    if mode == "dynamical":
        try:
            scv.tl.recover_dynamics(adata, n_jobs=-1)
            scv.tl.velocity(adata, mode="dynamical")
        except Exception as e:  # deliberate best effort: degrade to stochastic
            warnings.warn(
                f"Dynamical model failed ({e}). Falling back to stochastic.",
                RuntimeWarning, stacklevel=2,
            )
            scv.tl.velocity(adata, mode="stochastic")
    else:
        scv.tl.velocity(adata, mode=mode)

    scv.tl.velocity_graph(adata)

    # Pseudotime is optional for the pipeline, but a failure must be visible:
    # downstream code ordering cells by 'pseudotime' depends on this column.
    try:
        scv.tl.velocity_pseudotime(adata)
    except Exception as e:
        warnings.warn(
            f"velocity_pseudotime could not be computed ({e}); "
            "adata.obs['velocity_pseudotime'] will be unavailable.",
            RuntimeWarning, stacklevel=2,
        )

    if verbose:
        print("[scCS] Velocity pipeline complete.")
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
# ---------------------------------------------------------------------------
|
|
472
|
+
# Internal helpers
|
|
473
|
+
# ---------------------------------------------------------------------------
|
|
474
|
+
|
|
475
|
+
def _obs_column_as_float(adata, column: str) -> np.ndarray:
    """Return ``adata.obs[column]`` as a float array, with a readable error if it is not numeric."""
    try:
        return np.array(adata.obs[column], dtype=float)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"adata.obs['{column}'] could not be converted to float; "
            "a differentiation metric must be a numeric per-cell column."
        ) from err


def _resolve_metric(
    adata,
    metric: Union[str, np.ndarray],
    invert: bool,
) -> np.ndarray:
    """Resolve differentiation metric to a per-cell float array.

    Parameters
    ----------
    adata : AnnData
        Dataset the metric refers to. For ``metric='pseudotime'`` with no
        'velocity_pseudotime' column, scVelo is asked to compute it on this
        object in place (only if scVelo and a velocity graph are available).
    metric : str or np.ndarray
        - np.ndarray   : used as given (flattened); must have one value per cell
        - 'pseudotime' : adata.obs['velocity_pseudotime']; if it is missing and
          cannot be computed, uniform pseudo-random scores (fixed seed 0) are
          returned with a RuntimeWarning
        - 'cytotrace'  : first CytoTRACE2-style column found in adata.obs
        - any other str: that column of adata.obs
    invert : bool
        If True, return ``max - score`` so that high raw values map to low
        scores. For 'cytotrace' the flag is flipped: the score is inverted by
        default and ``invert=True`` keeps the raw orientation.

    Returns
    -------
    np.ndarray
        Float scores, shape (n_cells,), NaNs filled via ``_fill_nan``.

    Raises
    ------
    ValueError
        Wrong array length, missing column, or non-numeric column.
    """
    n_cells = adata.n_obs

    if isinstance(metric, np.ndarray):
        scores = np.asarray(metric, dtype=float).ravel()
        if len(scores) != n_cells:
            raise ValueError(
                f"Custom metric array has length {len(scores)}, "
                f"expected {n_cells}."
            )

    elif metric == "pseudotime":
        if "velocity_pseudotime" not in adata.obs:
            if _SCVELO_AVAILABLE and "velocity_graph" in adata.uns:
                scv.tl.velocity_pseudotime(adata)
            else:
                warnings.warn(
                    "velocity_pseudotime not found and cannot be computed. "
                    "Falling back to uniform scores (random ordering).",
                    RuntimeWarning, stacklevel=3,
                )
                scores = np.random.default_rng(0).uniform(0, 1, n_cells)
                if invert:
                    scores = 1.0 - scores
                return _fill_nan(scores)
        scores = _obs_column_as_float(adata, "velocity_pseudotime")

    elif metric == "cytotrace":
        # CytoTRACE2: accept the commonly used column names, first match wins.
        candidates = ["cytotrace2_score", "CytoTRACE2_Score", "cytotrace_score",
                      "CytoTRACE2", "cytotrace2"]
        found = next((c for c in candidates if c in adata.obs), None)
        if found is None:
            raise ValueError(
                "CytoTRACE2 score not found in adata.obs. "
                f"Expected one of: {candidates}. "
                "Run CytoTRACE2 first or pass the column name as metric."
            )
        scores = _obs_column_as_float(adata, found)
        # High CytoTRACE2 score = stem-like = LESS differentiated, so the
        # default orientation is inverted; the caller's flag toggles that.
        invert = not invert

    else:
        # Treat as column name in adata.obs
        if metric not in adata.obs:
            raise ValueError(
                f"Column '{metric}' not found in adata.obs. "
                f"Available columns: {list(adata.obs.columns)}"
            )
        scores = _obs_column_as_float(adata, metric)

    # Fill NaNs before inverting so that the maximum is well defined.
    scores = _fill_nan(scores)

    if invert:
        scores = scores.max() - scores

    return scores
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def _fill_nan(scores: np.ndarray) -> np.ndarray:
|
|
546
|
+
"""Replace NaN values with the column median."""
|
|
547
|
+
nan_mask = np.isnan(scores)
|
|
548
|
+
if nan_mask.any():
|
|
549
|
+
median = np.nanmedian(scores)
|
|
550
|
+
scores = scores.copy()
|
|
551
|
+
scores[nan_mask] = median
|
|
552
|
+
return scores
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def _graph_velocity_projection(
|
|
557
|
+
adata,
|
|
558
|
+
coords: np.ndarray,
|
|
559
|
+
verbose: bool = True,
|
|
560
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
561
|
+
"""Fallback: project velocity using the velocity graph transition matrix.
|
|
562
|
+
|
|
563
|
+
For each cell i, the velocity vector is the weighted average displacement
|
|
564
|
+
to its neighbors, weighted by the transition probability T[i, j]:
|
|
565
|
+
|
|
566
|
+
v_i = sum_j T[i,j] * (x_j - x_i)
|
|
567
|
+
|
|
568
|
+
Parameters
|
|
569
|
+
----------
|
|
570
|
+
adata : AnnData
|
|
571
|
+
coords : np.ndarray, shape (n_cells, 2)
|
|
572
|
+
verbose : bool
|
|
573
|
+
|
|
574
|
+
Returns
|
|
575
|
+
-------
|
|
576
|
+
vx, vy : np.ndarray, shape (n_cells,)
|
|
577
|
+
"""
|
|
578
|
+
import scipy.sparse as sp
|
|
579
|
+
|
|
580
|
+
if verbose:
|
|
581
|
+
print("[scCS] Using graph-based velocity projection...")
|
|
582
|
+
|
|
583
|
+
# Try velocity_graph first, then connectivities as fallback
|
|
584
|
+
T = None
|
|
585
|
+
for key in ["velocity_graph", "velocity_graph_neg"]:
|
|
586
|
+
if key in adata.uns:
|
|
587
|
+
T_raw = adata.uns[key]
|
|
588
|
+
if sp.issparse(T_raw):
|
|
589
|
+
T = T_raw
|
|
590
|
+
else:
|
|
591
|
+
T = sp.csr_matrix(T_raw)
|
|
592
|
+
break
|
|
593
|
+
|
|
594
|
+
if T is None:
|
|
595
|
+
# Last resort: use kNN connectivities
|
|
596
|
+
if "connectivities" in adata.obsp:
|
|
597
|
+
T = adata.obsp["connectivities"]
|
|
598
|
+
if verbose:
|
|
599
|
+
warnings.warn(
|
|
600
|
+
"velocity_graph not found. Using kNN connectivities as proxy.",
|
|
601
|
+
RuntimeWarning, stacklevel=2,
|
|
602
|
+
)
|
|
603
|
+
else:
|
|
604
|
+
warnings.warn(
|
|
605
|
+
"No velocity graph or connectivity matrix found. "
|
|
606
|
+
"Returning zero velocity vectors.",
|
|
607
|
+
RuntimeWarning, stacklevel=2,
|
|
608
|
+
)
|
|
609
|
+
return np.zeros(adata.n_obs), np.zeros(adata.n_obs)
|
|
610
|
+
|
|
611
|
+
# Row-normalize transition matrix
|
|
612
|
+
T = T.astype(float)
|
|
613
|
+
row_sums = np.array(T.sum(axis=1)).ravel()
|
|
614
|
+
row_sums[row_sums == 0] = 1.0
|
|
615
|
+
T_norm = sp.diags(1.0 / row_sums) @ T
|
|
616
|
+
|
|
617
|
+
# Expected position under transition
|
|
618
|
+
expected_coords = T_norm @ coords # (n_cells, 2)
|
|
619
|
+
V = expected_coords - coords # displacement = velocity
|
|
620
|
+
|
|
621
|
+
return V[:, 0], V[:, 1]
|