lotsofcells 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lotsofcells/entropy.py ADDED
@@ -0,0 +1,354 @@
1
+ """Symmetric divergence (KL-based) entropy score, plus the 1-class abundance test."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Optional, Sequence
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from ._stats import (
10
+ _ensure_cols,
11
+ _ensure_rows,
12
+ _table,
13
+ geom_mean,
14
+ pseudo_count_arcsin,
15
+ )
16
+ from ._utils import get_metadata
17
+
18
+
19
def _proportions_arcsin(
    tab: pd.DataFrame, label_order: Sequence[str], indexes: Sequence[str]
) -> np.ndarray:
    """Return row-normalised proportions of a (groups x covariables) table.

    The table is first passed through ``_ensure_rows`` (with ``label_order``)
    and ``_ensure_cols`` (with ``indexes``), its values are cast to float and
    handed to ``pseudo_count_arcsin``, and every row of the result is divided
    by its own sum, so each returned row sums to 1.

    Relationship to the R `entropyScore` normalisation (as recorded by the
    original author): in R the *random* contingency table is built from
    `data.frame(covariable, groups)`, giving shape (ncov, ngroups); the R
    code applies `apply(., 2, row/sum(row))` followed by `t()`, which is
    mathematically the same as row-normalising a (ngroups, ncov) matrix.
    Because the tables reaching this function are already laid out as
    (ngroups, ncov), one function serves both observed and random tables.
    """
    aligned = _ensure_cols(_ensure_rows(tab, label_order), indexes)
    transformed = pseudo_count_arcsin(aligned.values.astype(float))
    # keepdims=True keeps the sums as a column so the division broadcasts per row.
    return transformed / transformed.sum(axis=1, keepdims=True)
37
+
38
+
39
def _distance_surprise(p: np.ndarray, q: np.ndarray) -> float:
    """Symmetric divergence between two proportion vectors ``p`` and ``q``.

    Sums the ``geom_mean`` of ``|p * log2(p / q)|`` and the ``geom_mean`` of
    ``|q * log2(q / p)|``, so swapping the arguments yields the same value.
    """
    p_vs_q = np.abs(p * np.log2(p / q))
    q_vs_p = np.abs(q * np.log2(q / p))
    return geom_mean(p_vs_q) + geom_mean(q_vs_p)
41
+
42
+
43
def entropy_score(
    sc_object,
    main_variable: str,
    subtype_variable: str,
    label_order: Sequence[str],
    sample_id: Optional[str] = None,
    permutations: int = 1000,
    seed: Optional[int] = None,
    table: Optional[str] = None,
    plot: bool = True,
    verbose: bool = True,
    pdf_file: Optional[str] = None,
):
    """Symmetric divergence score for global proportion dysregulation between 2 groups.

    Parameters
    ----------
    sc_object
        Object handed to ``get_metadata`` to obtain the per-cell metadata table.
    main_variable
        Metadata column holding the group label of each cell (compared as str).
    subtype_variable
        Metadata column holding the covariable (e.g. cell type) of each cell.
    label_order
        One label (1-class mode) or two labels (2-class mode) of
        ``main_variable``. Cells with other labels are dropped.
    sample_id
        Optional metadata column identifying samples. Required in 1-class
        mode. In 2-class mode it makes the null draw one batch of cells per
        (group, sample) pair instead of one batch per group.
    permutations
        Number of Monte-Carlo iterations. Must be >= 1 in 2-class mode.
    seed
        Seed for ``numpy.random.default_rng``.
    table
        Forwarded to ``get_metadata`` as ``table=``.
    plot
        If true, attempt to draw the summary figure; a plotting failure is
        reported (when ``verbose``) but never aborts the computation.
    verbose
        If true, print progress messages.
    pdf_file
        Forwarded to the plotting helpers.

    Returns
    -------
    pandas.Series
        In 2-class mode: per-covariable relative entropies indexed by
        covariable, plus the summary fields ``entropy_score``, ``p.val``,
        ``mean.random.entropy`` and ``sd.random.entropy``.
    dict
        In 1-class mode (``len(label_order) == 1``): the summary returned by
        ``_one_class_test`` (analogue of the R `oneClassTest`).

    Raises
    ------
    ValueError
        If ``label_order`` is empty, contains labels absent from the data,
        has more than 2 labels, if ``sample_id`` is missing in 1-class mode,
        or if ``permutations`` is < 1 in 2-class mode.
    """
    # Fail fast: nothing below is meaningful without at least one label.
    if len(label_order) == 0:
        raise ValueError("label_order must be specified.")

    metadata = get_metadata(sc_object, table=table)

    main_vals = metadata[main_variable].astype(str).to_numpy()
    # Compute the set of observed labels once rather than once per requested label.
    present = set(np.unique(main_vals).tolist())
    missing = [l for l in label_order if l not in present]
    if missing:
        raise ValueError(f"Some groups in label_order not in data: {missing}")

    metadata = metadata.loc[np.isin(main_vals, list(label_order))].copy()
    groups = metadata[main_variable].astype(str).to_numpy()
    covariable = metadata[subtype_variable].astype(str).to_numpy()
    rng = np.random.default_rng(seed)

    if len(label_order) == 1:
        if sample_id is None:
            raise ValueError("In 1-class mode you must specify `sample_id`.")
        return _one_class_test(
            metadata,
            sample_id,
            covariable,
            permutations,
            rng,
            plot=plot,
            verbose=verbose,
            pdf_file=pdf_file,
        )

    if len(label_order) > 2:
        raise ValueError(
            f"Only 2 labels are allowed for entropy estimation, got "
            f"{len(label_order)}: {label_order}"
        )

    # With zero permutations the p-value and null summaries would be 0/0 (NaN).
    if permutations < 1:
        raise ValueError(f"permutations must be >= 1, got {permutations}")

    if verbose:
        print(
            "Computing entropy proportion over covariables for groups: "
            f"{label_order[0]} vs {label_order[1]}"
        )
    obs_tab = _table(groups, covariable)
    indexes = list(obs_tab.columns)
    contig = _proportions_arcsin(obs_tab, label_order, indexes)

    # Per-covariable relative entropies, one value per column of `contig`
    # (row 0 = first label, row 1 = second label). The formula is the one the
    # original author ported from the R implementation:
    #   abs(log2((x[1]*log2(x[2])) / (x[1]*log2(x[1]))))   (R, 1-based)
    # NOTE(review): confirm against the R source that it is applied per covariable.
    first, second = contig[0], contig[1]
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_entropies = np.abs(
            np.log2((first * np.log2(second)) / (first * np.log2(first)))
        )

    obs_score = _distance_surprise(contig[0], contig[1])

    # Build the "cell crowd": for every label, the list of batch sizes to draw
    # in each null iteration. Sizes are square roots of the cell counts.
    if sample_id is not None:
        samples = metadata[sample_id].astype(str).to_numpy()
        n_per_sample = (
            pd.crosstab(pd.Series(groups), pd.Series(samples)).reindex(label_order)
        )
        n_per_sample = np.sqrt(n_per_sample)
        cell_crowd = {}
        for cond in label_order:
            row = n_per_sample.loc[cond]
            # One batch per sample that actually contributes cells to this label.
            cell_crowd[cond] = list(row[row != 0].astype(int).to_numpy())
    else:
        counts = pd.Series(groups).value_counts().to_dict()
        # A single batch per label; kept as a one-element list so the null
        # loop below handles both modes uniformly.
        cell_crowd = {
            l: [int(round(np.sqrt(counts.get(l, 0))))] for l in label_order
        }

    if verbose:
        print(f"Starting Monte-Carlo simulation with n. permutations: {permutations}")

    null_scores = np.empty(permutations)
    for i in range(permutations):
        pieces_cov, pieces_grp = [], []
        for label in label_order:
            for n in cell_crowd[label]:
                # Null hypothesis: every batch is drawn from the pooled
                # covariable distribution, regardless of its label.
                draw = rng.choice(covariable, size=int(n), replace=True)
                pieces_cov.append(draw)
                pieces_grp.append(np.repeat(label, len(draw)))
        rand_tab = _table(np.concatenate(pieces_grp), np.concatenate(pieces_cov))
        p = _proportions_arcsin(rand_tab, label_order, indexes)
        null_scores[i] = _distance_surprise(p[0], p[1])

    # Fraction of null scores at least as large as the observed score.
    p_val = float((null_scores >= obs_score).sum() / permutations)

    if plot:
        # Plotting is best-effort: a failure must not discard the statistics.
        try:
            _plot_entropy(
                contig=contig,
                indexes=indexes,
                label_order=label_order,
                obs_score=obs_score,
                null_scores=null_scores,
                p_val=p_val,
                subtype_variable=subtype_variable,
                pdf_file=pdf_file,
            )
        except Exception as e:  # noqa: BLE001
            if verbose:
                print(f"(Plot skipped: {e})")

    out = pd.Series(rel_entropies, index=indexes)
    out["entropy_score"] = obs_score
    out["p.val"] = p_val
    out["mean.random.entropy"] = float(null_scores.mean())
    out["sd.random.entropy"] = float(null_scores.std(ddof=1))
    return out
186
+
187
+
188
def _plot_entropy(
    contig, indexes, label_order, obs_score, null_scores, p_val,
    subtype_variable, pdf_file=None,
):
    """Draw the two-panel summary figure and pass it to ``save_to_pdf``.

    Left panel: grouped bars of ``contig`` (one bar series per entry of
    ``label_order``, one tick per entry of ``indexes``). Right panel: the
    null scores as a jittered strip with a line at their median and the
    observed score highlighted. The jitter uses its own fixed-seed generator.
    """
    import matplotlib.pyplot as plt
    from ._utils import save_to_pdf

    fig, panels = plt.subplots(
        1, 2, figsize=(12, 5), gridspec_kw={"width_ratios": [3, 1]}
    )
    bar_ax, null_ax = panels[0], panels[1]

    # --- left panel: per-covariable proportions, side by side per label ---
    bar_width = 0.35
    positions = np.arange(len(indexes))
    colours = ["#9ECAE1", "#3182BD"]
    for row, label in enumerate(label_order):
        # Shift each label's bars half a width left/right of the tick centre.
        bar_ax.bar(
            positions + (row - 0.5) * bar_width,
            contig[row],
            bar_width,
            label=label,
            color=colours[row],
        )
    bar_ax.set_xticks(positions)
    bar_ax.set_xticklabels(indexes, rotation=45, ha="right")
    bar_ax.set_ylabel("proportion")
    bar_ax.set_title(
        f"Symmetric Divergence Score: {obs_score:.3f} | p.val.adj: {p_val:.3f}"
    )
    bar_ax.legend(title=f"Class: {subtype_variable}")

    # --- right panel: null distribution versus the observed score ---
    jitter_rng = np.random.default_rng(0)
    x_jitter = jitter_rng.uniform(-0.1, 0.1, size=len(null_scores))
    null_ax.scatter(x_jitter, null_scores, color="#D5BADB", alpha=0.5, s=15)
    null_ax.axhline(np.median(null_scores), color="#86608E", lw=1)
    null_ax.scatter([0], [obs_score], color="#F08080", s=80, zorder=5)
    null_ax.set_xticks([])
    null_ax.set_ylabel("symmetric divergence")

    plt.tight_layout()
    save_to_pdf(fig, pdf_file)
221
+
222
+
223
def _one_class_test(
    metadata,
    sample_id,
    covariable,
    permutations,
    rng,
    plot=True,
    verbose=True,
    pdf_file=None,
):
    """Permutation test for sample-level proportion variation in a single condition.

    In each iteration the samples are randomly split into two pseudo-groups,
    cells are drawn (with replacement) for every sample, and the symmetric
    divergence between the two pseudo-groups is recorded. The spread of the
    resulting scores is summarised as a coefficient of variation (CV) and a
    relative IQR.

    Departs from R's `oneClassTest` in one important way: the null draws each
    sample's cells from THAT SAMPLE'S own covariable distribution, not from
    the global pool. The R version sampled every cell from the global pool,
    which collapses both random pseudo-groups onto the same global
    distribution and produces a null that is essentially zero, so the user
    never observes any spread no matter how heterogeneous the real samples
    are. Drawing from per-sample pools preserves real per-sample structure
    and lets random partitions of those samples yield a null distribution
    whose spread reflects across-sample heterogeneity, which is what this
    test is meant to assess.

    Parameters
    ----------
    metadata
        Per-cell metadata table; only the ``sample_id`` column is read.
    sample_id
        Name of the metadata column identifying the sample of each cell.
    covariable
        Array of covariable values, positionally aligned with ``metadata`` rows.
    permutations
        Requested permutation count; the loop runs
        ``max(int(permutations) * 10, 100)`` iterations.
    rng
        ``numpy.random.Generator`` used for the partitions, the draws and the
        plot jitter.
    plot, verbose, pdf_file
        Plot the null distribution (best-effort), print progress, and the
        target forwarded to ``save_to_pdf``.

    Returns
    -------
    dict
        Keys ``cv``, ``variation`` ("Low" / "Medium" / "High"),
        ``relative_iqr``, ``mean.random.entropy``, ``sd.random.entropy`` and
        ``median.random.entropy``.

    Raises
    ------
    ValueError
        If fewer than 2 distinct samples are present.
    """
    samples = metadata[sample_id].astype(str).to_numpy()
    obs_tab = _table(samples, covariable)
    indexes = list(obs_tab.columns)
    n_per_sample = pd.Series(samples).value_counts()
    unique_samples = list(n_per_sample.index)
    if len(unique_samples) < 2:
        raise ValueError(
            "1-class entropy test needs at least 2 samples in `sample_id`."
        )

    sqrt_n = np.sqrt(n_per_sample)
    # Defensive fallback for zero counts; value_counts() on plain string data
    # does not report zero-count entries, so this is normally a no-op.
    sqrt_n[sqrt_n == 0] = 10
    cell_crowd = sqrt_n.to_dict()
    # Number of cells drawn per sample (truncated sqrt of its size, at least 1);
    # loop-invariant, so computed once.
    draw_sizes = {s: max(int(cell_crowd[s]), 1) for s in unique_samples}

    # Build a per-sample pool of covariable values (preserves the real cell
    # composition of each sample for the null draw).
    sample_pools = {s: covariable[samples == s] for s in unique_samples}
    n_g1 = max(1, round(len(unique_samples) / 2))

    # Mirror R's iteration count: seq(100) * seq(permutations/10) = 10*perms.
    n_iter = max(int(permutations) * 10, 100)
    null_scores = np.empty(n_iter)
    if verbose:
        print(f"Starting 1-class Monte-Carlo simulation: {n_iter} iterations")

    group_names = ("group1", "group2")
    for i in range(n_iter):
        perm = rng.permutation(len(unique_samples))
        # First n_g1 shuffled samples form group1, the remainder group2.
        halves = (perm[:n_g1], perm[n_g1:])
        pieces_cov, pieces_grp = [], []
        for name, member_idx in zip(group_names, halves):
            for k in member_idx:
                s = unique_samples[k]
                pool = sample_pools[s]
                if len(pool) == 0:
                    continue
                n = draw_sizes[s]
                pieces_cov.append(rng.choice(pool, size=n, replace=True))
                pieces_grp.append(np.repeat(name, n))
        rand_tab = _table(np.concatenate(pieces_grp), np.concatenate(pieces_cov))
        p = _proportions_arcsin(rand_tab, list(group_names), indexes)
        null_scores[i] = _distance_surprise(p[0], p[1])

    mean_null = float(null_scores.mean())
    sd_null = float(null_scores.std(ddof=1))
    median_null = float(np.median(null_scores))
    # CV in percent; undefined (reported as inf) when the mean is not positive.
    cv = float(sd_null / mean_null * 100) if mean_null > 0 else float("inf")
    if median_null > 0:
        q75 = np.percentile(null_scores, 75)
        q25 = np.percentile(null_scores, 25)
        relative_iqr = float((q75 - q25) / median_null)
    else:
        relative_iqr = float("nan")

    # Qualitative bands on the CV: <= 35 % Low, <= 50 % Medium, otherwise High.
    if cv <= 35:
        variation = "Low"
    elif cv <= 50:
        variation = "Medium"
    else:
        variation = "High"

    if verbose:
        print(f"Coefficient of Variation: {cv:.2f} %")
        print(f"Variation across samples is considered: {variation}")
        print(f"Relative IQR: {relative_iqr:.3f}")

    if plot:
        # Plotting is best-effort: a failure must not discard the statistics,
        # but it is reported (as in `entropy_score`) rather than hidden.
        try:
            import matplotlib.pyplot as plt
            from ._utils import save_to_pdf

            fig, ax = plt.subplots(figsize=(3.5, 5))
            jitter = rng.uniform(-0.1, 0.1, size=len(null_scores))
            ax.scatter(jitter, null_scores, color="#D5BADB", alpha=0.5, s=15)
            ax.axhline(median_null, color="#86608E", lw=1)
            ax.set_xlim(-0.5, 0.5)
            # Anchor the y-axis at 0 (or lower) and leave headroom above the max.
            lo = float(min(0.0, null_scores.min()))
            hi = float(null_scores.max())
            pad = max(1e-3, 0.1 * (hi - lo))
            ax.set_ylim(lo, hi + pad)
            ax.set_xticks([])
            ax.set_ylabel("symmetric divergence (null)")
            ax.set_title(
                f"1-class null distribution\n"
                f"median={median_null:.4f}  CV={cv:.1f}% ({variation})"
            )
            plt.tight_layout()
            save_to_pdf(fig, pdf_file)
        except Exception as e:  # noqa: BLE001
            if verbose:
                print(f"(Plot skipped: {e})")

    return {
        "cv": cv,
        "variation": variation,
        "relative_iqr": relative_iqr,
        "mean.random.entropy": mean_null,
        "sd.random.entropy": sd_null,
        "median.random.entropy": median_null,
    }