py-TranspaceR 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,199 @@
1
+ """Plotting functions: spatial visualization, heatmaps, UMAP, etc."""
2
+
3
+ import numpy as np
4
+
5
+ try:
6
+ import matplotlib
7
+ matplotlib.use("Agg")
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.colors as mcolors
10
+ HAS_MPL = True
11
+ except ImportError:
12
+ HAS_MPL = False
13
+
14
+
15
def plot_fov(x, y, cluster_labels, output_path=None, title="FoV"):
    """Scatter plot of cell centroids colored by cluster.

    Parameters
    ----------
    x, y : array-like
        Cell coordinates, one entry per cell.
    cluster_labels : array-like
        Cluster label per cell (lists are accepted as well as arrays).
    output_path : str, optional
        If truthy, the figure is saved there at 150 dpi; otherwise nothing
        is written.
    title : str
        Plot title.

    Returns
    -------
    None
        Does nothing when Matplotlib is not importable.
    """
    if not HAS_MPL:
        return
    x = np.asarray(x)
    y = np.asarray(y)
    # Convert so that `labels == label` is an element-wise mask even when the
    # caller passes a plain list.
    labels = np.asarray(cluster_labels)
    unique_labels = np.unique(labels)
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap(name, lut)`
    # works on both old and new releases. lut must be at least 1.
    cmap = plt.get_cmap("tab20", max(len(unique_labels), 1))
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        for i, label in enumerate(unique_labels):
            mask = labels == label
            ax.scatter(x[mask], y[mask], s=1, c=[cmap(i)], label=str(label), alpha=0.7)
        ax.set_title(title)
        ax.set_aspect("equal")
        # y increases downward (presumably to match image/pixel coordinates).
        ax.invert_yaxis()
        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so repeated calls do not leak.
        plt.close(fig)
43
+
44
+
45
def plot_fov_gene(x, y, expression_values, gene_name, output_path=None):
    """Spatial scatter of cells colored by the expression of a single gene.

    Colors come from ``color_convertion`` in the package ``utils`` module
    (intended as a white-yellow-orange-red scale). The figure is written to
    ``output_path`` only when it is truthy, and is always closed afterwards.
    Does nothing when Matplotlib is not importable.
    """
    if not HAS_MPL:
        return
    from .utils import color_convertion

    point_colors = color_convertion(expression_values)
    figure, axes = plt.subplots(figsize=(8, 8))
    axes.scatter(x, y, s=1, c=point_colors, alpha=0.7)
    axes.set_title(gene_name)
    axes.set_aspect("equal")
    axes.invert_yaxis()
    if output_path:
        figure.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(figure)
59
+
60
+
61
def plot_variogram(sill, range_val, alpha, nugget, model, r_data, v_data, r2, output_path=None):
    """Draw an empirical variogram with its fitted model curve.

    Parameters
    ----------
    sill, range_val, alpha, nugget : float
        Fitted model parameters. ``alpha`` is only used by the
        "Finetuned exponential" model; ``range_val`` and ``nugget`` only by
        the two exponential models.
    model : str
        "Constant", "Exponential" or "Finetuned exponential". Any other value
        draws the data points only.
    r_data, v_data : array-like
        Empirical distances and variogram values.
    r2 : float
        Goodness of fit, shown in the title.
    output_path : str, optional
        Destination file; nothing is written when falsy.
    """
    if not HAS_MPL:
        return
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(r_data, v_data, s=10, alpha=0.5, label="Data")

    # Evaluate the model slightly beyond the largest observed distance.
    grid = np.linspace(0, max(r_data) * 1.1, 100)

    if model == "Constant":
        ax.axhline(sill, color="red", label=f"Constant (C={sill:.3f})")
    elif model in ("Exponential", "Finetuned exponential"):
        finetuned = model == "Finetuned exponential"
        scaled = (grid ** alpha if finetuned else grid) / range_val
        curve = sill * (1 - np.exp(-scaled)) + nugget
        if finetuned:
            legend_text = f"FineExp (alpha={alpha:.3f})"
        else:
            legend_text = f"Exp (C={sill:.3f}, tau={range_val:.3f})"
        ax.plot(grid, curve, "r-", label=legend_text)

    ax.set_xlabel("Distance")
    ax.set_ylabel("Variogram")
    ax.set_title(f"R2={r2:.3f}, Model={model}")
    ax.legend()
    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
85
+
86
+
87
def save_umap(umap_coords, labels, output_path, title="UMAP"):
    """Save a UMAP embedding scatter plot colored by label.

    Parameters
    ----------
    umap_coords : array-like
        Embedding coordinates; column 0 is plotted on x and column 1 on y.
    labels : array-like
        Label per row of ``umap_coords`` (lists are accepted).
    output_path : str
        Destination file (saved at 150 dpi).
    title : str
        Plot title.

    Returns
    -------
    None
        Does nothing when Matplotlib is not importable.
    """
    if not HAS_MPL:
        return
    coords = np.asarray(umap_coords)
    # Array conversion makes `label_arr == lab` element-wise for list input.
    label_arr = np.asarray(labels)
    unique = np.unique(label_arr)
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap` remains.
    cmap = plt.get_cmap("tab20", max(len(unique), 1))
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        for i, lab in enumerate(unique):
            mask = label_arr == lab
            ax.scatter(coords[mask, 0], coords[mask, 1], s=1, c=[cmap(i)],
                       label=str(lab), alpha=0.5)
        ax.set_title(title)
        # markerscale enlarges the tiny (s=1) markers in the legend only.
        ax.legend(markerscale=10, fontsize=6)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
101
+
102
+
103
def save_heatmap_markers(expression, gene_names, cluster_labels, output_path,
                         top_n: int = 5):
    """Save a heatmap of per-cluster mean expression of marker genes.

    For each cluster, the ``top_n`` genes with the largest difference between
    in-cluster and out-of-cluster mean expression are taken as markers. The
    heatmap has one row per cluster and one column per marker (union over
    clusters, ordered cluster by cluster so each cluster's markers are
    adjacent).

    Parameters
    ----------
    expression : np.ndarray or scipy.sparse matrix
        Expression matrix (cells x genes).
    gene_names : sequence of str
        Gene names aligned with the columns of ``expression``; used as x labels.
    cluster_labels : array-like
        Cluster label per cell.
    output_path : str
        Destination file (saved at 150 dpi).
    top_n : int
        Number of markers taken per cluster.
    """
    if not HAS_MPL:
        return
    labels = np.asarray(cluster_labels)
    groups = np.unique(labels)

    def _col_mean(rows):
        # Column means as a 1-D array; sparse .mean() returns an np.matrix.
        return np.asarray(expression[rows, :].mean(axis=0)).ravel()

    group_means = [_col_mean(labels == g) for g in groups]

    markers = []
    for g, means_in in zip(groups, group_means):
        outside = labels != g
        means_out = _col_mean(outside) if np.any(outside) else np.zeros_like(means_in)
        top_idx = np.argsort(-(means_in - means_out))[:top_n]
        markers.extend(int(j) for j in top_idx)
    # De-duplicate while keeping first-seen (cluster) order.
    markers = list(dict.fromkeys(markers))

    # Rows = clusters, matching the y tick labels below.
    data = np.vstack(group_means)[:, markers]

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        im = ax.imshow(data, aspect="auto", cmap="viridis")
        ax.set_yticks(range(len(groups)))
        ax.set_yticklabels([str(g) for g in groups])
        ax.set_xticks(range(len(markers)))
        ax.set_xticklabels([str(gene_names[j]) for j in markers], rotation=90, fontsize=6)
        fig.colorbar(im, ax=ax)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
127
+
128
+
129
def save_annotation_plot(cluster_labels, cell_types, output_path):
    """Save a stacked bar chart of cell-type proportions within each cluster.

    One bar per cluster; each bar's segments are the fraction of that
    cluster's cells per cell type (the crosstab is row-normalised).
    Does nothing when Matplotlib is not importable.
    """
    if not HAS_MPL:
        return
    import pandas as pd

    frame = pd.DataFrame({"cluster": cluster_labels, "type": cell_types})
    proportions = pd.crosstab(index=frame["cluster"], columns=frame["type"],
                              normalize="index")
    fig, ax = plt.subplots(figsize=(10, 6))
    proportions.plot.bar(stacked=True, ax=ax)
    ax.set_ylabel("Proportion")
    # Place the legend outside the axes so it does not cover the bars.
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=6)
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
142
+
143
+
144
def save_boxplot(expression, gene_idx, gene_name, cluster_labels, output_path):
    """Save a boxplot of one gene's expression per cluster.

    Parameters
    ----------
    expression : np.ndarray or scipy.sparse matrix
        Expression matrix (cells x genes).
    gene_idx : int
        Column of ``expression`` to plot.
    gene_name : str
        Plot title.
    cluster_labels : array-like
        Cluster label per cell (lists are accepted).
    output_path : str
        Destination file (saved at 150 dpi).
    """
    if not HAS_MPL:
        return
    from scipy.sparse import issparse

    labels = np.asarray(cluster_labels)
    groups = np.unique(labels)
    if issparse(expression):
        # Sparse column slices are 2-D sparse matrices; densify just this column.
        column = expression.tocsc()[:, gene_idx].toarray().ravel()
    else:
        column = np.asarray(expression[:, gene_idx]).ravel()
    data = [column[labels == g] for g in groups]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.boxplot(data)
        # Set tick labels explicitly: boxplot's `labels=` keyword was renamed
        # to `tick_labels` in Matplotlib 3.9 and the old name is deprecated.
        ax.set_xticks(range(1, len(groups) + 1))
        ax.set_xticklabels([str(g) for g in groups])
        ax.set_title(gene_name)
        ax.set_ylabel("Expression")
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
156
+
157
+
158
def save_dendogram(data, labels, output_path):
    """Save a Ward-linkage hierarchical clustering dendrogram.

    ``data`` is whatever ``scipy.cluster.hierarchy.linkage`` accepts
    (observation rows or a condensed distance vector); ``labels`` name the
    leaves. Leaves are drawn rotated 90 degrees in a small font.
    Does nothing when Matplotlib is not importable.
    """
    if not HAS_MPL:
        return
    from scipy.cluster.hierarchy import dendrogram, linkage

    linkage_matrix = linkage(data, method="ward")
    fig, ax = plt.subplots(figsize=(12, 6))
    dendrogram(
        linkage_matrix,
        labels=labels,
        ax=ax,
        leaf_rotation=90,
        leaf_font_size=6,
    )
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
168
+
169
+
170
def save_geary_variance_plot(geary_values, variance_values, selected_mask,
                             output_path, title="Geary vs Variance"):
    """Save a scatter of per-gene variance against Geary score.

    All genes are drawn in black; genes picked by ``selected_mask`` are
    overdrawn in red and get a "Selected" legend entry.

    Parameters
    ----------
    geary_values, variance_values : array-like
        Per-gene Geary score (x) and variance (y).
    selected_mask : array-like or None
        Boolean mask (or index array) of selected genes; None or no True
        entries means nothing is highlighted.
    output_path : str
        Destination file (saved at 150 dpi).
    title : str
        Plot title.
    """
    if not HAS_MPL:
        return
    geary_values = np.asarray(geary_values)
    variance_values = np.asarray(variance_values)
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.scatter(geary_values, variance_values, s=5, alpha=0.3, c="black")
        has_selection = selected_mask is not None and bool(np.any(selected_mask))
        if has_selection:
            selected_mask = np.asarray(selected_mask)
            ax.scatter(geary_values[selected_mask], variance_values[selected_mask],
                       s=10, c="red", label="Selected")
        ax.set_xlabel("Geary Score")
        ax.set_ylabel("Variance")
        ax.set_title(title)
        # Only the "Selected" series carries a label; without it, legend()
        # would just warn that no labelled artists exist.
        if has_selection:
            ax.legend()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
186
+
187
+
188
def save_tissue_visualization(x, y, values, output_path, title="Tissue"):
    """Save a spatial scatter of cells colored by ``values`` (viridis) with a colorbar.

    The y axis is inverted and the aspect ratio is fixed to 1.
    Does nothing when Matplotlib is not importable.
    """
    if not HAS_MPL:
        return
    fig, ax = plt.subplots(figsize=(8, 8))
    points = ax.scatter(x, y, s=1, c=values, cmap="viridis", alpha=0.7)
    fig.colorbar(points, ax=ax)
    ax.set_title(title)
    ax.set_aspect("equal")
    ax.invert_yaxis()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
transspacer/qc.py ADDED
@@ -0,0 +1,74 @@
1
+ """Quality control: Otsu thresholding and QC gene/RNA thresholds."""
2
+
3
+ import numpy as np
4
+
5
+
6
def otsu_thresholding(x: np.ndarray, number_bins: int = 100) -> float:
    """Pick a split point of ``x`` that minimises summed within-class variance.

    Candidate thresholds are ``number_bins`` evenly spaced quantiles of ``x``
    (including min and max). For each candidate ``t``, values strictly below
    and strictly above ``t`` form the two classes (values equal to ``t`` are
    ignored). Each class contributes ``size * sample_variance`` (ddof=1), or 0
    when it has fewer than two members. The first candidate with the smallest
    total is returned.

    Parameters
    ----------
    x : np.ndarray
        1-D data vector.
    number_bins : int
        Number of quantile-based candidate thresholds.

    Returns
    -------
    float
        Selected threshold value.
    """
    candidates = np.quantile(x, np.linspace(0, 1, number_bins))

    def _weighted_var(part):
        # Size-weighted sample variance; too-small classes contribute nothing.
        return len(part) * np.var(part, ddof=1) if len(part) > 1 else 0.0

    scores = np.array([
        _weighted_var(x[x < cut]) + _weighted_var(x[x > cut])
        for cut in candidates
    ])
    return candidates[np.argmin(scores)]
34
+
35
+
36
def qc_gene_threshold(expression, meta_data_radius, gene_names):
    """QC gene threshold using Otsu on log10 gene counts.

    Negative-control probes are recognised by name: any gene whose name
    contains "Neg", "False" or "Blank". Otsu's threshold is computed on the
    log10 total counts of all genes with a non-zero total.

    Parameters
    ----------
    expression : sparse or dense matrix
        Gene expression (cells x genes).
    meta_data_radius : np.ndarray
        Cell radii. Not used by this implementation; kept for API compatibility.
    gene_names : sequence of str
        Gene names aligned with the columns of ``expression``.

    Returns
    -------
    dict
        "Otsu_threshold": the threshold on the raw (not log10) count scale.
        "Separability_score": fraction of negative-control probes (with a
        non-zero total) whose log10 count is below the threshold, or 0.0 if
        there are none.
    """
    # np.asarray(...).ravel() yields a 1-D array both for sparse input (whose
    # .sum returns an np.matrix) and for dense ndarray / np.matrix input.
    gene_size = np.asarray(expression.sum(axis=0), dtype=float).ravel()

    neg_patterns = ("Neg", "False", "Blank")
    # dtype=bool keeps this a valid boolean index even when gene_names is empty.
    neg_mask = np.array(
        [any(pat in str(g) for pat in neg_patterns) for g in gene_names],
        dtype=bool,
    )

    # Zero totals give log10(0) = -inf and are dropped below; silence the warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        log_size = np.log10(gene_size)
    finite = np.isfinite(log_size)
    neg_counts = log_size[neg_mask & finite]
    x = log_size[finite]

    threshold = otsu_thresholding(x)
    neg_before = np.sum(neg_counts < threshold)
    score = neg_before / len(neg_counts) if len(neg_counts) > 0 else 0.0

    return {"Otsu_threshold": 10**threshold, "Separability_score": score}
@@ -0,0 +1,67 @@
1
+ """Sparse matrix utilities: column variance, group-wise aggregation."""
2
+
3
+ import numpy as np
4
+ from scipy.sparse import issparse, csc_matrix
5
+
6
+
7
def colvars_sparse(mat) -> np.ndarray:
    """Column variance (population variance, ddof=0) of a sparse or dense matrix.

    Parameters
    ----------
    mat : scipy.sparse matrix or np.ndarray
        Input matrix (rows x columns). Sparse input is converted to CSC.

    Returns
    -------
    np.ndarray
        Variance of each column, dividing by the number of rows.
    """
    if not issparse(mat):
        return np.var(mat, axis=0)

    mat = csc_matrix(mat)
    if not mat.has_canonical_format:
        # Duplicate stored entries would be treated as separate values by the
        # formula below; merge them on a copy so the caller's matrix is untouched.
        mat = mat.copy()
        mat.sum_duplicates()

    n_rows, n_cols = mat.shape
    means = np.asarray(mat.mean(axis=0)).ravel()
    nnz_per_col = np.diff(mat.indptr)
    # Column index of every stored value, in CSC storage order.
    col_of_entry = np.repeat(np.arange(n_cols), nnz_per_col)
    sq_dev = (mat.data - means[col_of_entry]) ** 2
    sum_sq = np.bincount(col_of_entry, weights=sq_dev, minlength=n_cols)
    # Implicit zeros each contribute (0 - mean)^2.
    return (sum_sq + (n_rows - nnz_per_col) * means ** 2) / n_rows
37
+
38
+
39
def aggregate_sparse(expression, grouping) -> np.ndarray:
    """Group-wise column means on a matrix.

    Parameters
    ----------
    expression : sparse or dense matrix
        Expression matrix (cells x genes).
    grouping : array-like
        Cluster/group label per cell (lists are accepted).

    Returns
    -------
    np.ndarray
        Mean expression per group (genes x groups), columns ordered by
        ``np.unique(grouping)``, matching R's sapply output.
    """
    # Array conversion makes `grouping == g` element-wise; with a plain list it
    # is a single False and every group silently came out as zeros.
    grouping = np.asarray(grouping)
    groups = np.unique(grouping)
    result = np.zeros((expression.shape[1], len(groups)))
    for col, g in enumerate(groups):
        rows = np.flatnonzero(grouping == g)
        # .mean on sparse (or np.matrix) input returns a 1 x genes np.matrix.
        result[:, col] = np.asarray(expression[rows, :].mean(axis=0)).ravel()
    return result
@@ -0,0 +1,325 @@
1
+ """Spatial statistics: Geary's C, excess variance ratio, excess zero score."""
2
+
3
+ import numpy as np
4
+ from scipy.sparse import issparse, csc_matrix
5
+ from scipy.spatial import Delaunay
6
+ from scipy.sparse.csgraph import connected_components
7
+ from scipy.stats import norm, kurtosis as scipy_kurtosis
8
+ from scipy.optimize import curve_fit
9
+ from scipy.stats import pearsonr
10
+
11
+
12
+ def _build_spatial_weight_matrix(coords: np.ndarray):
13
+ """Build spatial weight matrix from Delaunay triangulation."""
14
+ from scipy.sparse import lil_matrix
15
+
16
+ tri = Delaunay(coords)
17
+ n = len(coords)
18
+ W = lil_matrix((n, n), dtype=float)
19
+
20
+ # Extract edges from simplices
21
+ edges = set()
22
+ for simplex in tri.simplices:
23
+ for i in range(3):
24
+ for j in range(i + 1, 3):
25
+ edge = (min(simplex[i], simplex[j]), max(simplex[i], simplex[j]))
26
+ edges.add(edge)
27
+
28
+ for i, j in edges:
29
+ W[i, j] = 1.0
30
+ W[j, i] = 1.0
31
+
32
+ W = csc_matrix(W)
33
+ # Normalize: W[i,j] = 1/degree for neighbors
34
+ row_sums = np.array(W.sum(axis=1)).flatten()
35
+ row_sums[row_sums == 0] = 1
36
+ # Symmetric normalization
37
+ W = W + W.T
38
+ W.data = np.ones_like(W.data) # binary adjacency
39
+ return csc_matrix(W)
40
+
41
+
42
def geary_c_score(expression: np.ndarray, coords: np.ndarray,
                  pvalue_threshold: float = 0.01) -> dict:
    """Compute Geary's C spatial autocorrelation score for each gene.

    Parameters
    ----------
    expression : np.ndarray or scipy.sparse matrix
        Expression matrix (cells x genes). Sparse input is densified one
        column at a time.
    coords : np.ndarray
        Spatial coordinates (cells x 2) used to build a Delaunay adjacency.
    pvalue_threshold : float
        Benjamini-Hochberg adjusted p-value threshold for gene selection.

    Returns
    -------
    dict
        "Values": 1/C per gene sorted in descending order (0 where 1/C is
            not finite).
        "Selected_genes": positions *in that sorted order* whose adjusted
            p-value is <= pvalue_threshold (map through "sort_idx" to get
            column indices).
        "P_values": BH-adjusted p-values in the original gene order, clipped
            below at 1e-300.
        "sort_idx": permutation that sorts genes by descending 1/C.
    """
    n_cells, n_genes = expression.shape
    N = n_cells

    # Keep the adjacency sparse: a dense copy needs O(N^2) memory and makes
    # every mat-vec below O(N^2) per gene.
    W = csc_matrix(_build_spatial_weight_matrix(coords))

    S_0 = float(W.sum())
    S_1 = 0.5 * float((W + W.T).power(2).sum())
    col_sums = np.asarray(W.sum(axis=0)).ravel()
    row_sums = np.asarray(W.sum(axis=1)).ravel()
    S_2 = float(np.sum((col_sums + row_sums) ** 2))

    if issparse(expression):
        # CSC makes the per-gene column extraction below cheap.
        expression = csc_matrix(expression)

    geary_values = np.zeros(n_genes)
    p_values = np.zeros(n_genes)

    for i in range(n_genes):
        if issparse(expression):
            x = expression[:, i].toarray().ravel().astype(float)
        else:
            x = np.asarray(expression[:, i], dtype=float).ravel()

        var_X = np.var(x, ddof=1) * N
        if var_X == 0:
            # Constant gene: report the null value C = 1 and p = 1.
            geary_values[i] = 1.0
            p_values[i] = 1.0
            continue

        # For symmetric W: sum_ij w_ij (x_i - x_j)^2 = product_1 - product_2.
        product_1 = 2.0 * float(np.sum(W @ (x ** 2)))
        product_2 = 2.0 * float(x @ (W @ x))
        # NOTE(review): np.var(ddof=1) * N equals N/(N-1) * sum((x - mean)^2),
        # so C carries an extra (N-1)/N factor versus the textbook definition.
        # Kept to preserve existing output; confirm against the R reference.
        geary_c = (N - 1) * (product_1 - product_2) / (2 * var_X * S_0)
        geary_values[i] = geary_c

        # Variance of C under randomisation, using the (non-excess) kurtosis b_2.
        b_2 = scipy_kurtosis(x, fisher=False)
        var_random = (
            (N - 1) * S_1 * (N**2 - 3*N + 3 - (N-1)*b_2)
            - 0.25 * (N - 1) * S_2 * (N**2 + 3*N - 6 - (N**2 - N + 2)*b_2)
            + S_0**2 * (N**2 - 3 - (N-1)**2 * b_2)
        ) / (N * (N-2) * (N-3) * S_0**2)

        # Lower tail: C below 1 indicates positive spatial autocorrelation.
        p_values[i] = norm.cdf(geary_c, loc=1, scale=np.sqrt(max(var_random, 1e-300)))

    # Invert C (higher = more spatial)
    inv_geary = 1.0 / geary_values
    inv_geary[~np.isfinite(inv_geary)] = 0

    # FDR correction (Benjamini-Hochberg). The import must sit inside the try,
    # otherwise a missing statsmodels raises before the fallback is reached.
    try:
        from statsmodels.stats.multitest import multipletests
    except ImportError:
        multipletests = None
    if multipletests is not None:
        _, adj_p, _, _ = multipletests(p_values, method="fdr_bh")
    else:
        # BH step-up: p_(k) * m / k, then a running minimum from the largest
        # p-value downwards so adjusted values are monotone; capped at 1.
        m = len(p_values)
        order = np.argsort(p_values)
        scaled = p_values[order] * m / np.arange(1, m + 1)
        scaled = np.minimum.accumulate(scaled[::-1])[::-1]
        adj_p = np.empty_like(p_values)
        adj_p[order] = np.minimum(scaled, 1.0)

    # Sort by geary score descending
    sort_idx = np.argsort(-inv_geary)
    inv_geary_sorted = inv_geary[sort_idx]
    adj_p_sorted = adj_p[sort_idx]

    selected_mask = adj_p_sorted <= pvalue_threshold
    selected_genes = np.where(selected_mask)[0].tolist()

    # Clip extreme p-values
    adj_p_clipped = adj_p.copy()
    adj_p_clipped[adj_p_clipped < 1e-300] = 1e-300

    return {
        "Values": inv_geary_sorted,
        "Selected_genes": selected_genes,
        "P_values": adj_p_clipped,
        "sort_idx": sort_idx
    }
141
+
142
+
143
def excess_variance_ratio_nb(expression: np.ndarray,
                             p_value_threshold: float = 0.01,
                             ratio_threshold: float = 1.5) -> dict:
    """Flag genes whose variance exceeds a Negative Binomial null model.

    A single NB dispersion ``s`` is fitted across genes by least squares
    through the origin on ``var - mean = s * mean**2``. Each gene's excess
    variance ratio is ``var / (mean + s * mean**2)``. Significance uses a
    normal approximation. Its spread is estimated by Monte Carlo: at 30
    quantiles of the gene means, 500 NB samples of ``n_cells`` values are
    drawn and the variance of their sample variance is recorded. That table
    is then interpolated in log-log space to each gene's mean.

    Parameters
    ----------
    expression : np.ndarray or scipy.sparse matrix
        Expression matrix (cells x genes).
    p_value_threshold : float
        Benjamini-Hochberg adjusted p-value threshold (strict ``<``).
    ratio_threshold : float
        Minimum excess variance ratio (strict ``>``).

    Returns
    -------
    dict
        "Selected_genes": column indices passing both thresholds.
        "Excess_variance_ratio": ratio per gene (0 where not finite).

    Notes
    -----
    Sampling uses NumPy's global random state (``np.random``), so results are
    reproducible only when the caller seeds it.
    """
    n_cells = expression.shape[0]
    if issparse(expression):
        # Convert once: repeatedly slicing columns out of CSR/COO is very slow.
        expr_csc = csc_matrix(expression)
        mean_expr = np.asarray(expr_csc.mean(axis=0)).ravel()
        var_gene = np.array([
            np.var(expr_csc[:, j].toarray().ravel(), ddof=1)
            for j in range(expr_csc.shape[1])
        ])
    else:
        dense = np.asarray(expression)
        mean_expr = np.mean(dense, axis=0)
        var_gene = np.var(dense, axis=0, ddof=1)

    mean_sq = mean_expr ** 2

    # NB null: var = mu + s * mu^2  =>  var - mu = s * mu^2 (OLS through origin)
    valid = mean_sq > 0
    y_fit = var_gene[valid] - mean_expr[valid]
    x_fit = mean_sq[valid]
    denom = float(np.sum(x_fit ** 2))
    s_est = float(np.sum(y_fit * x_fit) / denom) if denom > 0 else 0.0
    # The NB sampler needs a positive, finite dispersion (also guards NaN).
    if not np.isfinite(s_est) or s_est < 1e-10:
        s_est = 1e-10

    fitted_var = mean_expr + s_est * mean_sq
    with np.errstate(divide="ignore", invalid="ignore"):
        excess_ratio = var_gene / fitted_var
    excess_ratio[~np.isfinite(excess_ratio)] = 0

    # Monte Carlo estimate of Var(sample variance) under the fitted NB.
    n_sim = 500
    mu_values = np.quantile(mean_expr, np.linspace(0, 1, 30))
    # numpy's NB takes (n=size, p=size/(size+mu)); size = 1/dispersion.
    nb_size = 1.0 / s_est
    table_var = np.zeros(len(mu_values))
    for idx, mu in enumerate(mu_values):
        sims_var = np.zeros(n_sim)
        for sim in range(n_sim):
            samples = np.random.negative_binomial(nb_size, nb_size / (mu + nb_size), size=n_cells)
            sims_var[sim] = np.var(samples, ddof=1)
        table_var[idx] = np.var(sims_var)

    # Interpolate in log-log space (offsets avoid log10(0)).
    log_mu = np.log10(mu_values + 1e-300)
    log_var = np.log10(table_var + 1e-300)
    est_var = 10 ** np.interp(np.log10(mean_expr + 1e-300), log_mu, log_var)

    # Upper tail; norm.sf is more accurate than 1 - norm.cdf for tiny p-values.
    p_vals = norm.sf(var_gene, loc=fitted_var, scale=np.sqrt(est_var))

    # BH FDR. Import inside the try so the fallback is reachable without statsmodels.
    try:
        from statsmodels.stats.multitest import multipletests
    except ImportError:
        multipletests = None
    if multipletests is not None:
        _, adj_p, _, _ = multipletests(p_vals, method="fdr_bh")
    else:
        m = len(p_vals)
        order = np.argsort(p_vals)
        scaled = p_vals[order] * m / np.arange(1, m + 1)
        # Running minimum from the largest p-value keeps adjusted values monotone.
        scaled = np.minimum.accumulate(scaled[::-1])[::-1]
        adj_p = np.empty_like(p_vals)
        adj_p[order] = np.minimum(scaled, 1.0)

    selected = np.flatnonzero((adj_p < p_value_threshold) & (excess_ratio > ratio_threshold))

    return {
        "Selected_genes": selected.tolist(),
        "Excess_variance_ratio": excess_ratio
    }
228
+
229
+
230
def excess_zero_score_nb(expression: np.ndarray,
                         p_value_threshold: float = 0.01,
                         delta_threshold: float = 0.01) -> dict:
    """Flag genes with more zeros than a Negative Binomial model predicts.

    A shared NB size parameter ``theta`` is fitted across genes by nonlinear
    least squares on ``log P(0) = theta*log(theta) - theta*log(theta + mu)``.
    A gene's excess zero score is its observed zero fraction minus
    ``p0 = (theta / (theta + mu)) ** theta``. Under the null with independent
    cells, the zero count is Binomial(n_cells, p0), so the score has variance
    ``p0 * (1 - p0) / n_cells``. A one-sided normal approximation gives the
    p-value.

    Parameters
    ----------
    expression : np.ndarray or scipy.sparse matrix
        Expression matrix (cells x genes).
    p_value_threshold : float
        Benjamini-Hochberg adjusted p-value threshold (strict ``<``).
    delta_threshold : float
        Minimum excess zero score (strict ``>``).

    Returns
    -------
    dict
        "Selected_genes": column indices passing both thresholds.
        "Excess_zero_score": score per gene (all zeros when fewer than three
        genes are usable for the fit).
    """
    n_cells, n_genes = expression.shape

    if issparse(expression):
        mean_expr = np.asarray(expression.mean(axis=0)).ravel()
        # Count non-zeros: `expression == 0` would build a mostly-True sparse
        # matrix as large as the dense data.
        n_nonzero = np.asarray((expression != 0).sum(axis=0)).ravel()
        prop_zero = (n_cells - n_nonzero) / n_cells
        from .sparse_utils import colvars_sparse
        var_expr = colvars_sparse(expression)
        # colvars_sparse divides by n_cells (ddof=0); rescale to ddof=1 so the
        # sparse and dense paths agree.
        if n_cells > 1:
            var_expr = var_expr * n_cells / (n_cells - 1)
    else:
        dense = np.asarray(expression)
        mean_expr = np.mean(dense, axis=0)
        prop_zero = np.mean(dense == 0, axis=0)
        var_expr = np.var(dense, axis=0, ddof=1)

    # Moment estimate of the NB size: theta = mu^2 / (var - mu), i.e. the
    # reciprocal of the dispersion. The median over genes where it is positive
    # and finite seeds the fit and is the fallback if the fit fails.
    with np.errstate(divide="ignore", invalid="ignore"):
        theta_raw = mean_expr ** 2 / (var_expr - mean_expr)
    theta_raw = theta_raw[np.isfinite(theta_raw) & (theta_raw > 0)]
    theta_init = float(np.median(theta_raw)) if theta_raw.size > 0 else 1.0

    # Fit: log(prop_zero) ~ theta*log(theta) - theta*log(theta+mean)
    valid = (prop_zero > 0) & (prop_zero < 1) & (mean_expr > 0)
    if np.sum(valid) < 3:
        return {"Selected_genes": [], "Excess_zero_score": np.zeros(n_genes)}

    log_prop = np.log(prop_zero[valid])
    mu_v = mean_expr[valid]

    def nb_zero_model(mu, theta):
        # log of the NB(size=theta, mean=mu) probability of a zero count
        return theta * np.log(theta) - theta * np.log(theta + mu)

    try:
        popt, _ = curve_fit(nb_zero_model, mu_v, log_prop, p0=[theta_init],
                            bounds=([1e-10], [np.inf]), method="trf")
        theta = float(popt[0])
    except (RuntimeError, ValueError):
        # curve_fit raises RuntimeError on non-convergence, ValueError on bad input.
        theta = theta_init

    expected_prop_zero = (theta / (theta + mean_expr)) ** theta
    delta = prop_zero - expected_prop_zero

    # Exact null variance of the observed zero fraction (binomial), floored so
    # genes with p0 in {0, 1} still get a finite z-score.
    est_var = np.maximum(expected_prop_zero * (1 - expected_prop_zero) / n_cells, 1e-300)
    # Upper tail; norm.sf is more accurate than 1 - norm.cdf for tiny p-values.
    p_vals = norm.sf(prop_zero, loc=expected_prop_zero, scale=np.sqrt(est_var))

    # BH FDR. Import inside the try so the fallback is reachable without statsmodels.
    try:
        from statsmodels.stats.multitest import multipletests
    except ImportError:
        multipletests = None
    if multipletests is not None:
        _, adj_p, _, _ = multipletests(p_vals, method="fdr_bh")
    else:
        m = len(p_vals)
        order = np.argsort(p_vals)
        scaled = p_vals[order] * m / np.arange(1, m + 1)
        # Running minimum from the largest p-value keeps adjusted values monotone.
        scaled = np.minimum.accumulate(scaled[::-1])[::-1]
        adj_p = np.empty_like(p_vals)
        adj_p[order] = np.minimum(scaled, 1.0)

    selected = np.flatnonzero((adj_p < p_value_threshold) & (delta > delta_threshold))

    return {
        "Selected_genes": selected.tolist(),
        "Excess_zero_score": delta
    }