scCS-py 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scCS/embedding.py ADDED
@@ -0,0 +1,621 @@
1
+ """
2
+ embedding.py — Radial star embedding for scCS.
3
+
4
+ Constructs a custom 2D layout where:
5
+ - The bifurcation cluster (progenitor) sits at the origin (0, 0).
6
+ - Each terminal fate population occupies its own radial arm, evenly
7
+ spaced at 360/k degrees around the origin.
8
+ - Within each arm, cells are ordered along the radial axis by a
9
+ differentiation metric (pseudotime, CytoTRACE2, pathway score, etc.)
10
+ so that less-differentiated cells are close to the center and
11
+ more-differentiated cells are at the periphery.
12
+ - ONLY cells belonging to the bifurcation cluster or a terminal fate
13
+ are included. All other populations are excluded from the embedding.
14
+
15
+ The result is stored in adata_sub.obsm['X_sccs'] on the returned subset
16
+ AnnData, and looks like a star or sunburst when plotted — one arm per
17
+ fate, radiating from the progenitor.
18
+
19
+ Velocity projection
20
+ -------------------
21
+ RNA velocity vectors (from scVelo) are projected into this custom 2D
22
+ space by computing the transition-probability-weighted displacement of
23
+ each cell in the scCS coordinate system.
24
+
25
+ Differentiation metrics supported
26
+ ----------------------------------
27
+ - 'pseudotime' : scVelo velocity_pseudotime (default)
28
+ - 'cytotrace' : CytoTRACE2 score (column in adata.obs)
29
+ - any other str : name of a per-cell numeric column in adata.obs
30
+ - np.ndarray : directly supplied per-cell scores (shape n_cells,)
31
+
32
+ In all cases, higher score = more differentiated = farther from center.
33
+ If a metric is inverted (high = less differentiated), pass
34
+ invert_metric=True. The 'cytotrace' shortcut is already inverted automatically.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import warnings
40
+ from typing import List, Optional, Tuple, Union
41
+
42
+ import numpy as np
43
+
44
+ try:
45
+ import scvelo as scv
46
+ _SCVELO_AVAILABLE = True
47
+ except ImportError:
48
+ _SCVELO_AVAILABLE = False
49
+
50
+ try:
51
+ import scanpy as sc
52
+ _SCANPY_AVAILABLE = True
53
+ except ImportError:
54
+ _SCANPY_AVAILABLE = False
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Public API
59
+ # ---------------------------------------------------------------------------
60
+
61
def build_star_embedding(
    adata,
    bifurcation_cluster: str,
    terminal_cell_types: List[str],
    cluster_key: str = "leiden",
    differentiation_metric: Union[str, np.ndarray] = "pseudotime",
    invert_metric: bool = False,
    arm_scale: float = 10.0,
    jitter: float = 0.3,
    seed: int = 42,
) -> "anndata.AnnData":
    """Build the radial star embedding on a subset of adata.

    Only cells belonging to the bifurcation cluster or a terminal fate
    cluster are included. All other populations are excluded entirely.

    Parameters
    ----------
    adata : AnnData
        Full dataset. The embedding is written to a copy of the subset, not
        to ``adata``. The one exception: with ``'pseudotime'`` and no
        ``velocity_pseudotime`` column present, scVelo may be asked to
        compute it, which writes that column into ``adata.obs``.
    bifurcation_cluster : str
        Label of the progenitor/bifurcation cluster in adata.obs[cluster_key].
        These cells are placed around the origin.
    terminal_cell_types : list of str
        Labels of the k terminal fate populations. Each gets one radial arm.
    cluster_key : str
        Column in adata.obs with cluster labels.
    differentiation_metric : str or np.ndarray
        How to order cells along each arm:
        - 'pseudotime' : uses adata.obs['velocity_pseudotime']
        - 'cytotrace'  : uses a CytoTRACE2 score column in adata.obs
          (e.g. 'cytotrace2_score'); this shortcut is inverted automatically
        - any other str : uses adata.obs[differentiation_metric] directly
        - np.ndarray   : per-cell scores, shape (n_cells,) for the FULL adata
        Higher value = more differentiated = farther from center.
    invert_metric : bool
        If True, invert the metric so that high values map to the center.
        Applies to string metrics and to directly supplied arrays alike.
        For 'cytotrace' the flag toggles the automatic inversion, i.e.
        True restores the raw orientation.
    arm_scale : float
        Maximum radial distance (length of each arm).
    jitter : float
        Standard deviation of the Gaussian noise added perpendicular to each
        arm (and, halved, around the origin) to avoid overplotting.
    seed : int
        Random seed for jitter.

    Returns
    -------
    adata_sub : AnnData
        Subset containing ONLY bifurcation + terminal fate cells.
        Star embedding stored in adata_sub.obsm['X_sccs'].
        Metadata stored in adata_sub.uns['sccs'].

    Raises
    ------
    ValueError
        If no cell matches any requested label, or if a supplied metric
        array does not have one entry per cell of the full adata.
    """
    rng = np.random.default_rng(seed)
    bif_label = str(bifurcation_cluster)
    fate_labels = [str(f) for f in terminal_cell_types]

    # --- 0. Select the relevant cells ---
    keep_labels = {bif_label, *fate_labels}
    keep_mask = adata.obs[cluster_key].astype(str).isin(keep_labels).to_numpy()

    if keep_mask.sum() == 0:
        raise ValueError(
            f"No cells found matching bifurcation_cluster='{bifurcation_cluster}' "
            f"or terminal_cell_types={terminal_cell_types} in "
            f"adata.obs['{cluster_key}']."
        )

    # --- Resolve the metric on the FULL adata BEFORE subsetting ---
    # Pseudotime needs the intact neighbor/velocity graph, which does not
    # survive subsetting, so resolve first and slice afterwards.
    if isinstance(differentiation_metric, np.ndarray):
        arr = np.asarray(differentiation_metric, dtype=float).ravel()
        if len(arr) != adata.n_obs:
            raise ValueError(
                f"Custom metric array has length {len(arr)}, "
                f"expected {adata.n_obs} (full adata)."
            )
        if invert_metric:
            # Same "max - score" inversion the string metrics receive; NaNs
            # stay NaN here and are filled after slicing.
            arr = np.nanmax(arr) - arr
        metric_for_sub = arr[keep_mask]
    else:
        scores_full = _resolve_metric(adata, differentiation_metric, invert_metric)
        metric_for_sub = scores_full[keep_mask]

    adata_sub = adata[keep_mask].copy()
    obs_labels = adata_sub.obs[cluster_key].astype(str).values
    n_cells = adata_sub.n_obs

    print(f"[scCS] Subsetting: {keep_mask.sum()} / {adata.n_obs} cells kept")
    print(f" ({adata.n_obs - keep_mask.sum()} cells from other populations excluded)")
    absent = []
    for lbl in sorted(keep_labels):
        n = (obs_labels == lbl).sum()
        role = "progenitor" if lbl == bif_label else "fate"
        print(f" {lbl}: {n} cells ({role})")
        if n == 0:
            absent.append(lbl)
    if absent:
        # A typo in a label would otherwise silently produce an empty arm.
        warnings.warn(
            f"No cells found for label(s) {absent} in adata.obs['{cluster_key}']; "
            "the corresponding arm/origin will be empty.",
            RuntimeWarning,
            stacklevel=2,
        )

    # --- 1. Final per-cell scores for the subset (inversion already applied) ---
    scores = _fill_nan(np.asarray(metric_for_sub, dtype=float).ravel())

    # --- 2. Arm directions: k unit vectors evenly spaced over 360 degrees ---
    k = len(terminal_cell_types)
    arm_angles_deg = np.linspace(0.0, 360.0, k, endpoint=False)
    arm_angles_rad = np.radians(arm_angles_deg)
    arm_dirs = np.stack([np.cos(arm_angles_rad), np.sin(arm_angles_rad)], axis=1)  # (k, 2)

    # --- 3. Arm index per cell: -1 for the origin, j for fate j ---
    arm_assignment = np.full(n_cells, -1, dtype=int)
    for j, fate in enumerate(fate_labels):
        arm_assignment[obs_labels == fate] = j

    # --- 4. Place cells in 2D ---
    coords = np.zeros((n_cells, 2), dtype=float)

    # Bifurcation cluster: small isotropic cloud around the origin.
    bif_mask_sub = obs_labels == bif_label
    n_bif = bif_mask_sub.sum()
    if n_bif > 0:
        coords[bif_mask_sub] = rng.normal(0.0, jitter * 0.5, size=(n_bif, 2))

    for j, fate in enumerate(fate_labels):
        cell_mask = arm_assignment == j
        n_arm = cell_mask.sum()
        if n_arm == 0:
            continue

        # Normalize against the score range of this fate together with the
        # progenitors, so radius 0 corresponds to the least differentiated
        # cell of that combined group.
        arm_pool = scores[(obs_labels == fate) | bif_mask_sub]
        s_min, s_max = arm_pool.min(), arm_pool.max()
        if s_max <= s_min:
            # Degenerate range: spread the cells evenly along the arm.
            r = np.linspace(0.0, arm_scale, n_arm)
        else:
            r = (scores[cell_mask] - s_min) / (s_max - s_min) * arm_scale
            r = np.clip(r, 0.0, arm_scale)

        arm_dir = arm_dirs[j]
        perp_dir = np.array([-arm_dir[1], arm_dir[0]])
        perp_noise = rng.normal(0.0, jitter, size=n_arm)
        coords[cell_mask] = np.outer(r, arm_dir) + np.outer(perp_noise, perp_dir)

    # --- 5. Store results on the subset ---
    adata_sub.obsm["X_sccs"] = coords

    meta = adata_sub.uns.setdefault("sccs", {})
    meta["arm_angles_deg"] = arm_angles_deg
    meta["arm_dirs"] = arm_dirs
    meta["arm_scale"] = arm_scale
    meta["fate_names"] = fate_labels
    meta["bifurcation_cluster"] = bif_label
    meta["cluster_key"] = cluster_key
    # Integer positions of the kept cells in the original adata
    # (used later for velocity projection).
    meta["parent_indices"] = np.where(keep_mask)[0]

    adata_sub.obs["sccs_arm"] = arm_assignment
    adata_sub.obs["sccs_arm_name"] = [
        fate_labels[a] if a >= 0 else bif_label for a in arm_assignment
    ]

    print(
        f'\n[scCS] Star embedding built → adata_sub.obsm["X_sccs"] shape: {coords.shape}'
    )
    print(
        f' Arm angles: '
        + str({str(f): round(float(a), 1)
               for f, a in zip(terminal_cell_types, arm_angles_deg)})
    )

    return adata_sub
248
+
249
+
250
def _parent_positions(adata_sub, adata_full, stored_idx=None) -> np.ndarray:
    """Return, per subset cell (in subset order), its row position in adata_full."""
    if stored_idx is not None:
        return np.asarray(stored_idx, dtype=int)
    # Match by obs_names, preserving the subset's own row order.
    position_of = {name: i for i, name in enumerate(adata_full.obs_names)}
    return np.array([position_of[name] for name in adata_sub.obs_names], dtype=int)


def project_velocity_star(
    adata_sub,
    adata_full=None,
    verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Project RNA velocity into the scCS star embedding space.

    Three strategies are tried in order; each falls through to the next on
    failure (with a RuntimeWarning):

    1. scVelo's ``velocity_embedding`` run on ``adata_full`` with the star
       coordinates injected under a temporary basis, then sliced to the subset.
    2. A manual projection using the subset-by-subset block of
       ``adata_full.uns['velocity_graph']``, row-normalized.
    3. A projection from whatever graph is available on ``adata_sub`` itself.

    Parameters
    ----------
    adata_sub : AnnData
        Subset returned by build_star_embedding(). Must have 'X_sccs' in
        obsm. ``adata_sub.uns['sccs']['parent_indices']`` (set by
        build_star_embedding) is used to locate the subset cells in
        ``adata_full``; if absent, cells are matched by obs_names.
    adata_full : AnnData, optional
        The original full dataset with 'velocity_graph' in uns.
        If None, only strategy 3 is available.
    verbose : bool
        Print progress messages (and emit the strategy-3 warning).

    Returns
    -------
    vx, vy : np.ndarray, shape (n_sub_cells,)
        Velocity components in the scCS embedding.
        Also stored in adata_sub.obsm['velocity_sccs'].

    Raises
    ------
    ValueError
        If 'X_sccs' is missing from adata_sub.obsm.
    """
    if "X_sccs" not in adata_sub.obsm:
        raise ValueError(
            "X_sccs embedding not found. Run build_star_embedding() first."
        )

    coords_sub = np.array(adata_sub.obsm["X_sccs"])  # (n_sub, 2)

    # Positions of the subset cells in the full adata, stored during subsetting.
    stored_idx = adata_sub.uns.get("sccs", {}).get("parent_indices", None)
    has_full_graph = adata_full is not None and "velocity_graph" in adata_full.uns

    # ── Strategy 1: scVelo velocity_embedding on the full adata ──────────────
    if _SCVELO_AVAILABLE and has_full_graph:
        if verbose:
            print("[scCS] Projecting velocity via scVelo on full adata → slicing to subset...")
        try:
            idx = _parent_positions(adata_sub, adata_full, stored_idx)

            # Subset cells get their star coordinates; every other cell is
            # placed at (0, 0).
            # NOTE(review): presumably transitions from subset cells to those
            # excluded cells pull the projected velocity toward the origin —
            # verify against scVelo's velocity_embedding.
            coords_full = np.zeros((adata_full.n_obs, 2), dtype=float)
            coords_full[idx] = coords_sub

            adata_full.obsm["X_sccs_tmp"] = coords_full
            try:
                scv.tl.velocity_embedding(adata_full, basis="sccs_tmp")
                V_full = np.array(adata_full.obsm["velocity_sccs_tmp"])  # (n_full, 2)
            finally:
                # Always remove the temporary keys, including when scVelo
                # raises, so adata_full is left as it was found.
                for tmp_key in ("X_sccs_tmp", "velocity_sccs_tmp"):
                    if tmp_key in adata_full.obsm:
                        del adata_full.obsm[tmp_key]

            V_sub = V_full[idx]
            adata_sub.obsm["velocity_sccs"] = V_sub

            if verbose:
                print(f"[scCS] Velocity projected. Shape: {V_sub.shape}")
            return V_sub[:, 0], V_sub[:, 1]

        except Exception as e:
            warnings.warn(
                f"scVelo velocity_embedding on full adata failed ({e}). "
                "Falling back to graph-based projection.",
                RuntimeWarning,
                stacklevel=2,
            )

    # ── Strategy 2: graph-based projection using full adata's velocity_graph ──
    if has_full_graph:
        if verbose:
            print("[scCS] Using graph-based projection from full velocity_graph...")
        try:
            import scipy.sparse as sp

            T_full = adata_full.uns["velocity_graph"]
            if not sp.issparse(T_full):
                T_full = sp.csr_matrix(T_full)

            idx = _parent_positions(adata_sub, adata_full, stored_idx)

            # Subset-by-subset block of the graph, then row-normalize.
            T_sub = T_full[idx, :][:, idx]  # (n_sub, n_sub)
            row_sums = np.array(T_sub.sum(axis=1)).ravel()
            no_transition = row_sums == 0
            row_sums[no_transition] = 1.0
            T_norm = sp.diags(1.0 / row_sums) @ T_sub

            expected = np.asarray(T_norm @ coords_sub)  # (n_sub, 2)
            V_sub = expected - coords_sub
            # A cell without any transition inside the subset has no defined
            # direction; give it zero velocity rather than "-position".
            V_sub[no_transition] = 0.0

            adata_sub.obsm["velocity_sccs"] = V_sub

            if verbose:
                print(f"[scCS] Graph-based velocity projected. Shape: {V_sub.shape}")
            return V_sub[:, 0], V_sub[:, 1]

        except Exception as e:
            warnings.warn(
                f"Graph-based projection from full adata failed ({e}). "
                "Falling back to subset-only projection.",
                RuntimeWarning,
                stacklevel=2,
            )

    # ── Strategy 3: last resort — use whatever graph is in adata_sub ─────────
    if verbose:
        warnings.warn(
            "No full adata provided and no compatible velocity_graph found. "
            "Using subset-only graph (may have dimension issues). "
            "Pass adata_full=adata to project_velocity() for best results.",
            RuntimeWarning,
            stacklevel=2,
        )
    vx, vy = _graph_velocity_projection(adata_sub, coords_sub, verbose=verbose)
    adata_sub.obsm["velocity_sccs"] = np.stack([vx, vy], axis=1)
    return vx, vy
396
+
397
+
398
def run_velocity_pipeline(
    adata,
    mode: str = "dynamical",
    n_top_genes: int = 2000,
    n_pcs: int = 30,
    n_neighbors: int = 30,
    min_shared_counts: int = 20,
    verbose: bool = True,
) -> None:
    """Run the full scVelo RNA velocity pipeline in place on ``adata``.

    Steps: filter/normalize, PCA and neighbors (via scanpy, only if missing
    and scanpy is importable), moments, velocity, velocity graph and,
    best-effort, velocity pseudotime.

    Parameters
    ----------
    adata : AnnData
        Must contain layers 'spliced' and 'unspliced'. Modified in place.
    mode : {'dynamical', 'stochastic', 'steady_state'}
        Velocity model passed to scVelo. If 'dynamical' fails, the function
        warns and falls back to 'stochastic'.
    n_top_genes : int
        Number of genes kept by scv.pp.filter_and_normalize.
    n_pcs : int
        Number of principal components for PCA, neighbors and moments.
    n_neighbors : int
        Number of neighbors for the kNN graph and moments.
    min_shared_counts : int
        Passed to scv.pp.filter_and_normalize.
    verbose : bool
        Print progress messages.

    Raises
    ------
    ImportError
        If scvelo is not installed.
    ValueError
        If the 'spliced' or 'unspliced' layer is missing.
    """
    if not _SCVELO_AVAILABLE:
        raise ImportError("scvelo is required. pip install scvelo")

    missing = [l for l in ["spliced", "unspliced"] if l not in adata.layers]
    if missing:
        raise ValueError(
            f"Missing required layers: {missing}. "
            "These are generated by velocyto, STARsolo, or alevin-fry."
        )

    if verbose:
        print(f"[scCS] Running scVelo pipeline (mode='{mode}')...")

    scv.pp.filter_and_normalize(
        adata, min_shared_counts=min_shared_counts,
        n_top_genes=n_top_genes, log=True,
    )

    # Reuse an existing PCA / neighbor graph if the caller already has one.
    if "X_pca" not in adata.obsm and _SCANPY_AVAILABLE:
        sc.tl.pca(adata, n_comps=n_pcs)
    if "neighbors" not in adata.uns and _SCANPY_AVAILABLE:
        sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)

    scv.pp.moments(adata, n_pcs=n_pcs, n_neighbors=n_neighbors)

    if mode == "dynamical":
        try:
            scv.tl.recover_dynamics(adata, n_jobs=-1)
            scv.tl.velocity(adata, mode="dynamical")
        except Exception as e:
            warnings.warn(
                f"Dynamical model failed ({e}). Falling back to stochastic.",
                RuntimeWarning, stacklevel=2,
            )
            scv.tl.velocity(adata, mode="stochastic")
    else:
        scv.tl.velocity(adata, mode=mode)

    scv.tl.velocity_graph(adata)

    # Pseudotime is optional for the pipeline, so a failure is not fatal —
    # but report it instead of hiding it.
    try:
        scv.tl.velocity_pseudotime(adata)
    except Exception as e:
        warnings.warn(
            f"velocity_pseudotime could not be computed ({e}); "
            "continuing without it.",
            RuntimeWarning, stacklevel=2,
        )

    if verbose:
        print("[scCS] Velocity pipeline complete.")
469
+
470
+
471
+ # ---------------------------------------------------------------------------
472
+ # Internal helpers
473
+ # ---------------------------------------------------------------------------
474
+
475
def _resolve_metric(
    adata,
    metric: Union[str, np.ndarray],
    invert: bool,
) -> np.ndarray:
    """Turn a differentiation-metric spec into one float score per cell.

    ``metric`` may be an array (used as-is after a length check),
    'pseudotime', 'cytotrace', or the name of a column in ``adata.obs``.
    NaNs are replaced via ``_fill_nan`` and, when inversion is in effect,
    the result is ``max - score``.
    """
    expected_len = adata.n_obs
    obs = adata.obs

    if isinstance(metric, np.ndarray):
        values = np.asarray(metric, dtype=float).ravel()
        if len(values) != expected_len:
            raise ValueError(
                f"Custom metric array has length {len(values)}, "
                f"expected {expected_len}."
            )

    elif metric == "pseudotime":
        if "velocity_pseudotime" not in obs:
            can_compute = _SCVELO_AVAILABLE and "velocity_graph" in adata.uns
            if not can_compute:
                warnings.warn(
                    "velocity_pseudotime not found and cannot be computed. "
                    "Falling back to uniform scores (random ordering).",
                    RuntimeWarning, stacklevel=3,
                )
                # Fixed-seed placeholder scores in [0, 1).
                placeholder = np.random.default_rng(0).uniform(0, 1, expected_len)
                if invert:
                    placeholder = 1.0 - placeholder
                return _fill_nan(placeholder)
            # Writes adata.obs['velocity_pseudotime'] on the passed object.
            scv.tl.velocity_pseudotime(adata)
        values = np.array(adata.obs["velocity_pseudotime"], dtype=float)

    elif metric == "cytotrace":
        # Accept the first of several common CytoTRACE2 column names.
        column_options = ["cytotrace2_score", "CytoTRACE2_Score", "cytotrace_score",
                          "CytoTRACE2", "cytotrace2"]
        hit = next((name for name in column_options if name in obs), None)
        if hit is None:
            raise ValueError(
                "CytoTRACE2 score not found in adata.obs. "
                f"Expected one of: {column_options}. "
                "Run CytoTRACE2 first or pass the column name as metric."
            )
        values = np.array(obs[hit], dtype=float)
        # High CytoTRACE2 = stem-like = LESS differentiated, so the flag is
        # toggled: invert=False yields an inverted score, invert=True keeps
        # the raw orientation.
        invert = not invert

    else:
        # Any other string is taken as a column name in adata.obs.
        if metric not in obs:
            raise ValueError(
                f"Column '{metric}' not found in adata.obs. "
                f"Available columns: {list(adata.obs.columns)}"
            )
        values = np.array(obs[metric], dtype=float)

    values = _fill_nan(values)
    return values.max() - values if invert else values
543
+
544
+
545
+ def _fill_nan(scores: np.ndarray) -> np.ndarray:
546
+ """Replace NaN values with the column median."""
547
+ nan_mask = np.isnan(scores)
548
+ if nan_mask.any():
549
+ median = np.nanmedian(scores)
550
+ scores = scores.copy()
551
+ scores[nan_mask] = median
552
+ return scores
553
+
554
+
555
+
556
+ def _graph_velocity_projection(
557
+ adata,
558
+ coords: np.ndarray,
559
+ verbose: bool = True,
560
+ ) -> Tuple[np.ndarray, np.ndarray]:
561
+ """Fallback: project velocity using the velocity graph transition matrix.
562
+
563
+ For each cell i, the velocity vector is the weighted average displacement
564
+ to its neighbors, weighted by the transition probability T[i, j]:
565
+
566
+ v_i = sum_j T[i,j] * (x_j - x_i)
567
+
568
+ Parameters
569
+ ----------
570
+ adata : AnnData
571
+ coords : np.ndarray, shape (n_cells, 2)
572
+ verbose : bool
573
+
574
+ Returns
575
+ -------
576
+ vx, vy : np.ndarray, shape (n_cells,)
577
+ """
578
+ import scipy.sparse as sp
579
+
580
+ if verbose:
581
+ print("[scCS] Using graph-based velocity projection...")
582
+
583
+ # Try velocity_graph first, then connectivities as fallback
584
+ T = None
585
+ for key in ["velocity_graph", "velocity_graph_neg"]:
586
+ if key in adata.uns:
587
+ T_raw = adata.uns[key]
588
+ if sp.issparse(T_raw):
589
+ T = T_raw
590
+ else:
591
+ T = sp.csr_matrix(T_raw)
592
+ break
593
+
594
+ if T is None:
595
+ # Last resort: use kNN connectivities
596
+ if "connectivities" in adata.obsp:
597
+ T = adata.obsp["connectivities"]
598
+ if verbose:
599
+ warnings.warn(
600
+ "velocity_graph not found. Using kNN connectivities as proxy.",
601
+ RuntimeWarning, stacklevel=2,
602
+ )
603
+ else:
604
+ warnings.warn(
605
+ "No velocity graph or connectivity matrix found. "
606
+ "Returning zero velocity vectors.",
607
+ RuntimeWarning, stacklevel=2,
608
+ )
609
+ return np.zeros(adata.n_obs), np.zeros(adata.n_obs)
610
+
611
+ # Row-normalize transition matrix
612
+ T = T.astype(float)
613
+ row_sums = np.array(T.sum(axis=1)).ravel()
614
+ row_sums[row_sums == 0] = 1.0
615
+ T_norm = sp.diags(1.0 / row_sums) @ T
616
+
617
+ # Expected position under transition
618
+ expected_coords = T_norm @ coords # (n_cells, 2)
619
+ V = expected_coords - coords # displacement = velocity
620
+
621
+ return V[:, 0], V[:, 1]