scCS-py 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scCS/__init__.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ scCS — Single-cell Commitment Scores with radial star embedding.
3
+
4
+ Generalizes the 2-state commitment score framework from:
5
+
6
+ Kriukov et al. (2025) "Single-cell transcriptome of myeloid cells in
7
+ response to transplantation of human retinal neurons reveals reversibility
8
+ of microglial activation"
9
+
10
+ to any number of cell fates (k-furcations), with:
11
+ - User-supplied bifurcation cluster (e.g., leiden cluster '17')
12
+ - Radial star embedding: progenitor at origin, each fate on its own arm
13
+ - Cells ordered along arms by differentiation metric (pseudotime,
14
+ CytoTRACE2, pathway score, or any custom per-cell score)
15
+ - Population-level scores: unCS, nCS, commitment vector, entropy
16
+ - Per-cell fate affinity scores
17
+
18
+ Quick start
19
+ -----------
20
+ >>> import scCS
21
+ >>> scorer = scCS.CommitmentScorer(
22
+ ... adata,
23
+ ... bifurcation_cluster='17', # leiden cluster at the bifurcation
24
+ ... terminal_cell_types=['FateA', 'FateB', 'FateC'],
25
+ ... cluster_key='leiden',
26
+ ... )
27
+ >>> scorer.build_embedding(differentiation_metric='pseudotime')
28
+ >>> scorer.fit()
29
+ >>> result = scorer.score()
30
+ >>> print(result.summary())
31
+ >>> scorer.plot_star(result)
32
+
33
+ For k=2 (reproducing manuscript):
34
+ >>> scorer = scCS.CommitmentScorer(
35
+ ... adata,
36
+ ... bifurcation_cluster='17',
37
+ ... terminal_cell_types=['homeostatic', 'activated'],
38
+ ... cluster_key='leiden',
39
+ ... )
40
+ >>> scorer.build_embedding(differentiation_metric='pseudotime')
41
+ >>> scorer.fit()
42
+ >>> result = scorer.score()
43
+ >>> # result.pairwise_nCS[0, 1] should be ~8.066 (manuscript value)
44
+ """
45
+
46
# Package metadata. The version string matches the 0.3.2 wheel this module
# is distributed in.
__version__ = "0.3.2"
__author__ = "Emil Kriukov"
48
+
49
+ # Main API
50
+ from .trajectory import CommitmentScorer
51
+
52
+ # Fate map
53
+ from .bifurcation import FateMap, build_fate_map
54
+
55
+ # Embedding
56
+ from .embedding import (
57
+ build_star_embedding,
58
+ project_velocity_star,
59
+ run_velocity_pipeline,
60
+ )
61
+
62
+ # Core math (for advanced users)
63
+ from .scores import (
64
+ CommitmentScoreResult,
65
+ compute_magnitudes,
66
+ compute_angles,
67
+ bin_angles,
68
+ equal_sectors,
69
+ centroid_sectors,
70
+ compute_sector_magnitudes,
71
+ compute_unCS,
72
+ compute_nCS,
73
+ compute_commitment_vector,
74
+ # Entropy
75
+ compute_population_entropy, # aggregate velocity-mass entropy
76
+ compute_mean_cell_entropy, # mean per-cell k-way entropy
77
+ compute_per_fate_cell_entropy, # per-fate binary cell entropy, shape (k,)
78
+ compute_nn_cell_entropy, # NN-smoothed per-cell entropy, shape (n_cells,)
79
+ compute_commitment_entropy, # backward-compat alias for compute_population_entropy
80
+ compute_pairwise_cs_matrix,
81
+ compute_cell_scores,
82
+ )
83
+
84
+ # Driver genes
85
+ from .drivers import (
86
+ get_velocity_drivers,
87
+ get_deg_drivers,
88
+ )
89
+
90
+ # Pathway enrichment
91
+ from .enrichment import (
92
+ run_enrichment_per_fate,
93
+ export_enrichment_tables,
94
+ )
95
+
96
+ # Plotting
97
+ from .plot import (
98
+ plot_star_embedding,
99
+ plot_star_panels,
100
+ plot_rose,
101
+ plot_pairwise_cs,
102
+ plot_commitment_bar,
103
+ plot_commitment_heatmap,
104
+ plot_subset_comparison,
105
+ plot_expression_trends,
106
+ plot_nn_entropy_elbow,
107
+ )
108
+
109
+ __all__ = [
110
+ # Main class
111
+ "CommitmentScorer",
112
+ # Fate map
113
+ "FateMap",
114
+ "build_fate_map",
115
+ # Embedding
116
+ "build_star_embedding",
117
+ "project_velocity_star",
118
+ "run_velocity_pipeline",
119
+ # Results
120
+ "CommitmentScoreResult",
121
+ # Core math
122
+ "compute_magnitudes",
123
+ "compute_angles",
124
+ "bin_angles",
125
+ "equal_sectors",
126
+ "centroid_sectors",
127
+ "compute_sector_magnitudes",
128
+ "compute_unCS",
129
+ "compute_nCS",
130
+ "compute_commitment_vector",
131
+ "compute_population_entropy",
132
+ "compute_mean_cell_entropy",
133
+ "compute_per_fate_cell_entropy",
134
+ "compute_nn_cell_entropy",
135
+ "compute_commitment_entropy", # backward-compat alias
136
+ "compute_pairwise_cs_matrix",
137
+ "compute_cell_scores",
138
+ # Driver genes
139
+ "get_velocity_drivers",
140
+ "get_deg_drivers",
141
+ # Pathway enrichment
142
+ "run_enrichment_per_fate",
143
+ "export_enrichment_tables",
144
+ # Plots
145
+ "plot_star_embedding",
146
+ "plot_star_panels",
147
+ "plot_rose",
148
+ "plot_pairwise_cs",
149
+ "plot_commitment_bar",
150
+ "plot_commitment_heatmap",
151
+ "plot_subset_comparison",
152
+ "plot_expression_trends",
153
+ "plot_nn_entropy_elbow",
154
+ ]
scCS/bifurcation.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ bifurcation.py — Cluster-level fate map construction for scCS.
3
+
4
+ In scCS, the bifurcation point is explicitly defined by the user as a
5
+ single cluster (e.g., leiden cluster '17'). There is no automatic
6
+ fate detection — the user supplies:
7
+
8
+ bifurcation_cluster : the progenitor/root cluster label
9
+ terminal_cell_types : list of terminal fate cluster labels
10
+
11
+ This module builds a standardized FateMap from those labels, computing
12
+ centroids in the scCS star embedding space (X_sccs) and collecting
13
+ per-fate cell indices.
14
+
15
+ The FateMap is the single source of truth consumed by CommitmentScorer.score().
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional
23
+
24
+ import numpy as np
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # FateMap dataclass
29
+ # ---------------------------------------------------------------------------
30
+
31
@dataclass
class FateMap:
    """Standardized description of k cell fates for commitment scoring.

    Attributes
    ----------
    bifurcation_cluster : str
        Label of the progenitor/root cluster supplied by the user.
    fate_names : list of str
        Human-readable labels for each terminal fate (length k).
    fate_centroids : np.ndarray, shape (k, 2)
        Mean 2D position of each fate's cells in the scCS embedding.
    root_centroid : np.ndarray, shape (2,)
        Mean 2D position of the bifurcation cluster cells.
        In the scCS star embedding this is always near (0, 0).
    root_cells : np.ndarray of int
        Indices of bifurcation cluster cells in adata.
    fate_cell_indices : list of np.ndarray
        Per-fate arrays of cell indices.
    arm_angles_deg : np.ndarray, shape (k,)
        Angle (degrees) of each fate's radial arm in the star embedding.
    cluster_key : str
        The obs column used for cluster labels.
    k : int
        Number of fates (read-only property).

    Notes
    -----
    Equality compares every field, using ``np.array_equal`` for the array
    fields. Instances are unhashable, as for any mutable dataclass with
    ``eq=True``.
    """

    bifurcation_cluster: str
    fate_names: List[str]
    fate_centroids: np.ndarray
    root_centroid: np.ndarray
    root_cells: np.ndarray
    fate_cell_indices: List[np.ndarray]
    arm_angles_deg: np.ndarray
    cluster_key: str

    @property
    def k(self) -> int:
        """Number of fates, i.e. ``len(fate_names)``."""
        return len(self.fate_names)

    def __eq__(self, other: object) -> bool:
        # The dataclass-generated __eq__ compares tuples of fields, which calls
        # bool() on element-wise ndarray comparisons and raises
        # "truth value of an array is ambiguous" for any two distinct
        # FateMaps. Compare array-valued fields with np.array_equal instead.
        # Because __eq__ is defined here, @dataclass keeps it rather than
        # generating its own.
        if other.__class__ is not self.__class__:
            return NotImplemented
        if (
            self.bifurcation_cluster != other.bifurcation_cluster
            or self.cluster_key != other.cluster_key
            or list(self.fate_names) != list(other.fate_names)
        ):
            return False
        for field_name in (
            "fate_centroids",
            "root_centroid",
            "root_cells",
            "arm_angles_deg",
        ):
            if not np.array_equal(getattr(self, field_name), getattr(other, field_name)):
                return False
        if len(self.fate_cell_indices) != len(other.fate_cell_indices):
            return False
        return all(
            np.array_equal(mine, theirs)
            for mine, theirs in zip(self.fate_cell_indices, other.fate_cell_indices)
        )

    def summary(self) -> str:
        """Return a multi-line, human-readable description of the fate map.

        The first lines describe the root cluster. They are followed by one
        line per fate giving its cell count, centroid and arm angle.
        """
        lines = [
            f"FateMap (bifurcation_cluster='{self.bifurcation_cluster}', k={self.k})",
            f" Cluster key : '{self.cluster_key}'",
            f" Root cells : {len(self.root_cells)}",
            f" Root centroid: ({self.root_centroid[0]:.3f}, {self.root_centroid[1]:.3f})",
        ]
        for j, name in enumerate(self.fate_names):
            n = len(self.fate_cell_indices[j])
            c = self.fate_centroids[j]
            a = self.arm_angles_deg[j]
            lines.append(
                f" Fate {j}: '{name}' n_cells={n} "
                f"centroid=({c[0]:.2f}, {c[1]:.2f}) arm_angle={a:.1f}°"
            )
        return "\n".join(lines)
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # FateMap construction
90
+ # ---------------------------------------------------------------------------
91
+
92
def _centroid_angle_deg(delta: np.ndarray) -> float:
    """Return the direction of 2D vector ``delta`` in degrees, in [0, 360)."""
    return float(np.degrees(np.arctan2(delta[1], delta[0])) % 360.0)


def build_fate_map(
    adata,
    bifurcation_cluster: str,
    terminal_cell_types: List[str],
    cluster_key: str = "leiden",
    verbose: bool = True,
) -> FateMap:
    """Build a FateMap from user-supplied cluster labels.

    This is the only fate-detection strategy in scCS. The user explicitly
    names the bifurcation cluster and all terminal fate clusters.

    Parameters
    ----------
    adata : AnnData
        Must have X_sccs in obsm (built by build_star_embedding).
    bifurcation_cluster : str
        Label of the progenitor cluster in adata.obs[cluster_key].
        Example: '17' (leiden cluster 17)
    terminal_cell_types : list of str
        Labels of the k terminal fate clusters.
        Example: ['Monocyte', 'DC', 'Neutrophil']
        Labels are compared as strings. Labels with no cells are skipped
        with a warning, and repeated labels are kept only once.
    cluster_key : str
        Column in adata.obs with cluster labels.
    verbose : bool
        If True, print cell counts and centroids for the root and each fate.

    Returns
    -------
    FateMap
        Arm angles are taken from ``adata.uns['sccs']`` when that metadata
        lists the fate. Otherwise they are computed from the direction
        root centroid -> fate centroid.

    Raises
    ------
    ValueError
        If X_sccs is missing, the bifurcation cluster has no cells, or none
        of the terminal fate labels are present.
    KeyError
        If ``cluster_key`` is not a column of ``adata.obs``.
    """
    if "X_sccs" not in adata.obsm:
        raise ValueError(
            "X_sccs embedding not found in adata.obsm. "
            "Run CommitmentScorer.build_embedding() before build_fate_map()."
        )
    if cluster_key not in adata.obs:
        raise KeyError(
            f"Cluster key '{cluster_key}' not found in adata.obs. "
            f"Available columns: {list(adata.obs.columns)}"
        )

    obs_labels = adata.obs[cluster_key].astype(str).values
    embedding = np.array(adata.obsm["X_sccs"])
    bif_label = str(bifurcation_cluster)

    # --- Validate bifurcation cluster ---
    bif_mask = obs_labels == bif_label
    if not bif_mask.any():
        available = sorted(set(obs_labels))
        raise ValueError(
            f"Bifurcation cluster '{bifurcation_cluster}' not found in "
            f"adata.obs['{cluster_key}']. "
            f"Available labels: {available}"
        )
    root_cells = np.where(bif_mask)[0]
    root_centroid = embedding[root_cells].mean(axis=0)

    if verbose:
        print(
            f"[scCS] Bifurcation cluster '{bifurcation_cluster}': "
            f"{len(root_cells)} cells, "
            f"centroid=({root_centroid[0]:.2f}, {root_centroid[1]:.2f})"
        )

    # --- Validate and collect terminal fates ---
    fate_names: List[str] = []
    fate_centroids = []
    fate_cell_indices = []
    skipped = []
    seen = set()

    for name in terminal_cell_types:
        label = str(name)
        # A repeated label would create two arms over exactly the same cells.
        if label in seen:
            warnings.warn(
                f"Terminal fate '{name}' is listed more than once. "
                "Ignoring the duplicate.",
                stacklevel=2,
            )
            continue
        seen.add(label)

        if label == bif_label:
            warnings.warn(
                f"Terminal fate '{name}' is the bifurcation cluster itself; "
                "its arm will coincide with the root cells.",
                stacklevel=2,
            )

        mask = obs_labels == label
        n = int(mask.sum())
        if n == 0:
            warnings.warn(
                f"Terminal fate '{name}' not found in adata.obs['{cluster_key}']. "
                "Skipping.",
                stacklevel=2,
            )
            skipped.append(name)
            continue

        idx = np.where(mask)[0]
        centroid = embedding[idx].mean(axis=0)
        fate_names.append(label)
        fate_cell_indices.append(idx)
        fate_centroids.append(centroid)

        if verbose:
            print(
                f"[scCS] Fate '{name}': {n} cells, "
                f"centroid=({centroid[0]:.2f}, {centroid[1]:.2f})"
            )

    if not fate_names:
        raise ValueError(
            "No valid terminal fate clusters found. "
            f"Skipped: {skipped}"
        )

    if skipped:
        warnings.warn(
            f"Skipped {len(skipped)} fate(s) not found in data: {skipped}",
            stacklevel=2,
        )

    fate_centroids_arr = np.array(fate_centroids)

    # --- Retrieve arm angles from embedding metadata ---
    # build_star_embedding stores these in adata.uns['sccs'].
    sccs_meta = adata.uns.get("sccs", {})
    stored_fates = sccs_meta.get("fate_names", [])
    stored_angles = sccs_meta.get("arm_angles_deg", None)

    fate_to_angle = {}
    if stored_angles is not None and len(stored_fates) == len(stored_angles):
        # Key by str() so stored labels (possibly ints or numpy scalars) match
        # fate_names, which are str()-normalized above.
        fate_to_angle = {
            str(stored): float(angle)
            for stored, angle in zip(stored_fates, stored_angles)
        }

    # Fall back to the root -> fate centroid direction for fates that have no
    # stored arm angle.
    arm_angles_deg = np.array(
        [
            fate_to_angle[name]
            if name in fate_to_angle
            else _centroid_angle_deg(fate_centroids_arr[j] - root_centroid)
            for j, name in enumerate(fate_names)
        ],
        dtype=float,
    )

    fate_map = FateMap(
        bifurcation_cluster=bif_label,
        fate_names=fate_names,
        fate_centroids=fate_centroids_arr,
        root_centroid=root_centroid,
        root_cells=root_cells,
        fate_cell_indices=fate_cell_indices,
        arm_angles_deg=arm_angles_deg,
        cluster_key=cluster_key,
    )

    if verbose:
        print(f"[scCS] FateMap built: k={fate_map.k} fates")

    return fate_map
scCS/drivers.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ drivers.py — Driver gene identification for scCS fate arms.
3
+
4
+ Two complementary strategies:
5
+
6
+ 1. Velocity-based drivers
7
+ For each fate arm, rank genes by their mean scVelo velocity in arm cells.
8
+ High positive velocity = gene is being actively upregulated along that fate.
9
+ Requires the 'velocity' layer (from scVelo pipeline).
10
+
11
+ 2. DEG-based drivers
12
+ For each fate arm, run a Wilcoxon rank-sum test comparing arm cells vs
13
+ the bifurcation (progenitor) cluster. Returns logFC and adjusted p-value
14
+ per gene, with a significance flag.
15
+
16
+ Both functions operate on adata_sub (the subset returned by build_star_embedding),
17
+ which contains only bifurcation + terminal fate cells.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import warnings
23
+ from typing import Dict, List, Optional
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # 1. Velocity-based driver genes
31
+ # ---------------------------------------------------------------------------
32
+
33
def get_velocity_drivers(
    adata_sub,
    fate_names: List[str],
    cluster_key: str,
    bifurcation_cluster: str,
    n_top: int = 50,
) -> Dict[str, pd.DataFrame]:
    """Rank genes by mean scVelo velocity in each fate arm's cells.

    Parameters
    ----------
    adata_sub : AnnData
        Subset containing only bifurcation + terminal fate cells.
        Must have the 'velocity' layer (from scVelo). The layer may be dense
        or scipy-sparse.
    fate_names : list of str
        Terminal fate cluster labels (compared as strings).
    cluster_key : str
        Column in adata_sub.obs with cluster labels.
    bifurcation_cluster : str
        Label of the progenitor cluster (used for context only).
    n_top : int
        Number of top driver genes to print per fate.

    Returns
    -------
    dict : fate_name -> DataFrame with columns [gene, mean_velocity, rank]
        Sorted by mean_velocity descending (most upregulated first). Genes
        whose velocity is NaN in every cell of the fate are dropped. Fates
        with no cells are skipped with a warning.

    Raises
    ------
    ValueError
        If the 'velocity' layer is missing.
    """
    if "velocity" not in adata_sub.layers:
        raise ValueError(
            "'velocity' layer not found in adata_sub. "
            "Run the scVelo pipeline first (scorer.compute_velocity() or "
            "scvelo.tl.velocity())."
        )

    import scipy.sparse as sp

    velocity = adata_sub.layers["velocity"]  # (n_cells, n_genes)
    genes = adata_sub.var_names
    obs_labels = adata_sub.obs[cluster_key].astype(str).values
    results: Dict[str, pd.DataFrame] = {}

    for name in fate_names:
        mask = obs_labels == str(name)
        if not mask.any():
            warnings.warn(
                f"No cells found for fate '{name}' in adata_sub. Skipping.",
                stacklevel=2,
            )
            continue

        # Select this fate's rows *before* densifying, so a sparse velocity
        # layer is never materialized in full for all cells.
        V_fate = velocity[np.flatnonzero(mask)]
        if sp.issparse(V_fate):
            V_fate = V_fate.toarray()
        V_fate = np.asarray(V_fate, dtype=float)  # (n_fate_cells, n_genes)

        with warnings.catch_warnings():
            # Genes without fitted velocity are all-NaN and trigger
            # "Mean of empty slice"; their NaN means are dropped below.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean_vel = np.nanmean(V_fate, axis=0)

        df = pd.DataFrame({
            "gene": genes,
            "mean_velocity": mean_vel,
        }).dropna(subset=["mean_velocity"])

        df = df.sort_values("mean_velocity", ascending=False).reset_index(drop=True)
        df["rank"] = df.index + 1
        results[name] = df

        print(f"\n── Velocity drivers: {name} (top {n_top}) ──")
        print(
            df.head(n_top)[["rank", "gene", "mean_velocity"]]
            .to_string(index=False)
        )

    return results
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # 2. DEG-based driver genes
114
+ # ---------------------------------------------------------------------------
115
+
116
def get_deg_drivers(
    adata_sub,
    fate_names: List[str],
    cluster_key: str,
    bifurcation_cluster: str,
    n_top: int = 50,
    pval_cutoff: float = 0.05,
    logfc_cutoff: float = 0.25,
) -> Dict[str, pd.DataFrame]:
    """Find DEGs for each fate arm vs the bifurcation cluster (Wilcoxon).

    For each fate arm, compares arm cells against progenitor (bifurcation)
    cells using a Wilcoxon rank-sum test via scanpy.

    Parameters
    ----------
    adata_sub : AnnData
        Subset containing only bifurcation + terminal fate cells.
    fate_names : list of str
        Terminal fate cluster labels (compared as strings).
    cluster_key : str
        Column in adata_sub.obs with cluster labels.
    bifurcation_cluster : str
        Label of the progenitor cluster (reference group).
    n_top : int
        Number of top significant DEGs to print per fate.
    pval_cutoff : float
        Adjusted p-value threshold for significance.
    logfc_cutoff : float
        Minimum absolute log fold-change for significance.

    Returns
    -------
    dict : fate_name -> DataFrame with columns:
        [gene, logfoldchange, pval, pval_adj, significant]
        Sorted by logfoldchange descending. The following are skipped with a
        warning: fates with fewer than 5 cells, a fate equal to the
        bifurcation cluster, and fates where scanpy fails. If the bifurcation
        cluster has fewer than 5 cells, an empty dict is returned.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    """
    try:
        import scanpy as sc
    except ImportError as err:
        raise ImportError(
            "scanpy is required for DEG analysis. pip install scanpy"
        ) from err

    obs_labels = adata_sub.obs[cluster_key].astype(str).values
    bif_label = str(bifurcation_cluster)
    results: Dict[str, pd.DataFrame] = {}

    # The progenitor group is the same for every fate: compute and validate it
    # once instead of once per fate (which also repeated the warning).
    bif_mask = obs_labels == bif_label
    n_bif = int(bif_mask.sum())
    if n_bif < 5:
        warnings.warn(
            f"Bifurcation cluster '{bifurcation_cluster}' has only {n_bif} cells. "
            "Skipping DEG analysis (need ≥5).",
            stacklevel=2,
        )
        return results

    for name in fate_names:
        # Use the string label everywhere. rank_genes_groups results are
        # recarrays whose *field names* are group labels, so indexing them
        # with a non-string label (e.g. an int) would select a row instead.
        group = str(name)
        if group == bif_label:
            warnings.warn(
                f"Fate '{name}' is the bifurcation cluster itself. "
                "Skipping DEG analysis.",
                stacklevel=2,
            )
            continue

        fate_mask = obs_labels == group
        n_fate = int(fate_mask.sum())
        if n_fate < 5:
            warnings.warn(
                f"Fate '{name}' has only {n_fate} cells. "
                "Skipping DEG analysis (need ≥5).",
                stacklevel=2,
            )
            continue

        # The reference label must differ from the fate label; otherwise a
        # fate literally named "progenitor" would merge with the reference.
        ref_label = "progenitor" if group != "progenitor" else "_progenitor"

        # Subset to fate + progenitor only for this pairwise comparison.
        adata_pair = adata_sub[fate_mask | bif_mask].copy()
        pair_labels = adata_pair.obs[cluster_key].astype(str).values
        adata_pair.obs["_deg_group"] = pd.Categorical(
            np.where(pair_labels == group, group, ref_label),
            categories=[group, ref_label],
        )

        try:
            sc.tl.rank_genes_groups(
                adata_pair,
                groupby="_deg_group",
                groups=[group],
                reference=ref_label,
                method="wilcoxon",
                key_added="rank_genes",
                pts=True,
            )
        except Exception as e:
            # Best-effort per fate: report and continue with the other fates.
            warnings.warn(
                f"rank_genes_groups failed for fate '{name}': {e}",
                stacklevel=2,
            )
            continue

        rg = adata_pair.uns["rank_genes"]
        df = pd.DataFrame({
            "gene": rg["names"][group],
            "logfoldchange": rg["logfoldchanges"][group],
            "pval": rg["pvals"][group],
            "pval_adj": rg["pvals_adj"][group],
        })
        df["significant"] = (
            (df["pval_adj"] < pval_cutoff)
            & (df["logfoldchange"].abs() > logfc_cutoff)
        )
        df = df.sort_values("logfoldchange", ascending=False).reset_index(drop=True)
        results[name] = df

        n_sig = df["significant"].sum()
        n_up = ((df["logfoldchange"] > logfc_cutoff) & df["significant"]).sum()
        n_dn = ((df["logfoldchange"] < -logfc_cutoff) & df["significant"]).sum()

        print(f"\n── DEG drivers: {name} vs progenitor ──")
        print(f" Significant: {n_sig} (up: {n_up}, down: {n_dn})")
        sig_df = df[df["significant"]].head(n_top)
        if len(sig_df) > 0:
            print(
                sig_df[["gene", "logfoldchange", "pval_adj"]]
                .to_string(index=False)
            )
        else:
            print(" (no significant DEGs at current thresholds)")

    return results