sclab 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- sclab/__init__.py +1 -1
- sclab/_sclab.py +7 -3
- sclab/dataset/_dataset.py +1 -1
- sclab/dataset/processor/_processor.py +19 -4
- sclab/examples/processor_steps/__init__.py +2 -0
- sclab/examples/processor_steps/_doublet_detection.py +68 -0
- sclab/examples/processor_steps/_integration.py +47 -20
- sclab/examples/processor_steps/_neighbors.py +24 -4
- sclab/examples/processor_steps/_pca.py +11 -6
- sclab/examples/processor_steps/_preprocess.py +14 -1
- sclab/examples/processor_steps/_qc.py +22 -6
- sclab/gui/__init__.py +0 -0
- sclab/gui/components/__init__.py +7 -0
- sclab/gui/components/_guided_pseudotime.py +482 -0
- sclab/gui/components/_transfer_metadata.py +186 -0
- sclab/methods/__init__.py +16 -0
- sclab/preprocess/__init__.py +19 -0
- sclab/preprocess/_cca.py +154 -0
- sclab/preprocess/_cca_integrate.py +109 -0
- sclab/preprocess/_filter_obs.py +42 -0
- sclab/preprocess/_harmony.py +421 -0
- sclab/preprocess/_harmony_integrate.py +53 -0
- sclab/preprocess/_normalize_weighted.py +61 -0
- sclab/preprocess/_subset.py +208 -0
- sclab/preprocess/_transfer_metadata.py +137 -0
- sclab/preprocess/_transform.py +82 -0
- sclab/preprocess/_utils.py +96 -0
- sclab/tools/__init__.py +0 -0
- sclab/tools/cellflow/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
- sclab/tools/cellflow/pseudotime/__init__.py +0 -0
- sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
- sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
- sclab/tools/cellflow/utils/__init__.py +0 -0
- sclab/tools/cellflow/utils/density_nd.py +215 -0
- sclab/tools/cellflow/utils/interpolate.py +334 -0
- sclab/tools/cellflow/utils/smoothen.py +124 -0
- sclab/tools/cellflow/utils/times.py +55 -0
- sclab/tools/differential_expression/__init__.py +5 -0
- sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
- sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
- sclab/tools/doublet_detection/__init__.py +5 -0
- sclab/tools/doublet_detection/_scrublet.py +64 -0
- sclab/tools/labeling/__init__.py +6 -0
- sclab/tools/labeling/sctype.py +233 -0
- sclab/utils/__init__.py +5 -0
- sclab/utils/_write_excel.py +510 -0
- {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/METADATA +6 -2
- sclab-0.3.1.dist-info/RECORD +82 -0
- sclab-0.2.5.dist-info/RECORD +0 -45
- {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/WHEEL +0 -0
- {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/licenses/LICENSE +0 -0
sclab/tools/cellflow/density_dynamics/_density_dynamics.py (new file)

@@ -0,0 +1,349 @@
import logging
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
from anndata import AnnData
from numpy import floating
from numpy.typing import NDArray
from scipy.integrate import cumulative_trapezoid
from scipy.interpolate import BSpline, interp1d
from scipy.signal import find_peaks

from ..utils.density_nd import fit_density_1d
from ..utils.times import guess_trange

logger = logging.getLogger(__name__)


def density(
    adata: AnnData,
    time_key: str = "pseudotime",
    t_range: tuple[float, float] | None = None,
    periodic: bool | None = None,
    bandwidth: float = 1 / 64,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    plot_density: bool = False,
    plot_density_fit: bool = False,
    plot_density_fit_derivative: bool = False,
    plot_histogram: bool = False,
    histogram_nbins: int = 50,
):
    if t_range is None and time_key in adata.uns:
        # using stored t_range
        t_range = adata.uns[time_key]["t_range"]
    else:
        # guessing t_range
        pts = adata.obs[time_key].values
        pts = pts[np.isfinite(pts)]
        t_range = guess_trange(pts)
        if pts.size < 500:
            logger.warning(
                "Guessing t_range may not be accurate for fewer than 500 points."
                " Consider setting the pseudotime_t_range parameter instead."
            )

    if periodic is None and time_key in adata.uns and "periodic" in adata.uns[time_key]:
        periodic = adata.uns[time_key]["periodic"]
    else:
        periodic = False

    times = adata.obs[time_key].values
    lam = 1 / max_grid_size / 1e4
    rslt, bspl = fit_density_1d(
        times=times,
        t_range=t_range,
        periodic=periodic,
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
        max_grid_size=max_grid_size,
        lam=lam,
    )

    if time_key not in adata.uns:
        adata.uns[time_key] = {
            "t_range": list(t_range),
            "periodic": periodic,
        }

    t, c, k = bspl.tck
    density_bspline_tck = dict(t=t.tolist(), c=c.tolist(), k=k)
    adata.uns[time_key].update(
        {
            "density": {
                "params": {
                    "bandwidth": bandwidth,
                    "algorithm": algorithm,
                    "kernel": kernel,
                    "metric": metric,
                    "max_grid_size": max_grid_size,
                },
                "density_bspline_tck": density_bspline_tck,
            }
        }
    )

    if plot_density | plot_density_fit | plot_density_fit_derivative | plot_histogram:
        from ..utils.density_nd import density_result_1d

        density_result_1d(
            rslt,
            data=times[~np.isnan(times)],
            density_fit_lam=lam,
            plot_density=plot_density,
            plot_density_fit=plot_density_fit,
            plot_density_fit_derivative=plot_density_fit_derivative,
            plot_histogram=plot_histogram,
            histogram_nbins=histogram_nbins,
            show=True,
        )


def density_dynamics(
    adata: AnnData,
    time_key: str = "pseudotime",
    t_range: tuple[float, float] | None = None,
    periodic: bool | None = None,
    bandwidth: float = 1 / 64,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    derivative: int = 0,
    mode: Literal["peaks", "valleys"] = "peaks",
    find_peaks_kwargs: dict = {},
    plot_density: bool = False,
    plot_density_fit: bool = False,
    plot_density_fit_derivative: bool = False,
    plot_histogram: bool = False,
    histogram_nbins: int = 50,
):
    if t_range is None:
        test = time_key in adata.uns and "t_range" in adata.uns[time_key]
        assert test, f"t_range must be provided for time_key: {time_key}"
        t_range = adata.uns[time_key]["t_range"]

    if periodic is None:
        if time_key in adata.uns and "periodic" in adata.uns[time_key]:
            periodic = adata.uns[time_key]["periodic"]
        else:
            periodic = False

    times = adata.obs[time_key].values
    lam = 1 / max_grid_size / 1e4
    rslt, bspl = fit_density_1d(
        times=times,
        t_range=t_range,
        periodic=periodic,
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
        max_grid_size=max_grid_size,
        lam=lam,
    )

    t = np.linspace(*t_range, 2**16 + 1)
    y = bspl.derivative(derivative)(t)
    if mode == "peaks":
        pass
    elif mode == "valleys":
        y = -y

    tmin, tmax = t_range
    tspan = tmax - tmin
    if periodic:
        tt = np.concatenate([t[:-1] + i * tspan for i in range(3)]) - tspan
        yy = np.tile(y[:-1], 3)
    else:
        tt = t
        yy = y

    peak_height = find_peaks_kwargs.pop("height", 0.0)
    peak_height = peak_height * y.max()

    peaks, _ = find_peaks(yy, height=peak_height, **find_peaks_kwargs)
    peak_times = tt[peaks]
    peak_heights = yy[peaks]

    peaks_mask = np.logical_and(peak_times >= tmin, peak_times < tmax)
    peak_times = peak_times[peaks_mask]
    peak_heights = peak_heights[peaks_mask]

    timepoints = peak_times - tmin
    if periodic:
        deltas = (timepoints - np.roll(timepoints, 1)) % tspan
    else:
        timepoints = np.insert(timepoints, 0, 0)
        deltas = timepoints[1:] - timepoints[:-1]

    if time_key not in adata.uns:
        adata.uns[time_key] = {}

    t, c, k = bspl.tck
    density_bspline_tck = dict(t=t.tolist(), c=c.tolist(), k=k)
    adata.uns[time_key].update(
        {
            f"density_dynamics_d{derivative}_{mode}": {
                "times": peak_times,
                "deltas": deltas,
                "heights": peak_heights,
                "params": {
                    "bandwidth": bandwidth,
                    "algorithm": algorithm,
                    "kernel": kernel,
                    "metric": metric,
                    "max_grid_size": max_grid_size,
                    "find_peaks_kwargs": {"height": peak_height, **find_peaks_kwargs},
                },
                "density_bspline_tck": density_bspline_tck,
            }
        }
    )

    if plot_density | plot_density_fit | plot_density_fit_derivative | plot_histogram:
        from ..utils.density_nd import density_result_1d

        ax = density_result_1d(
            rslt,
            data=times[~np.isnan(times)],
            density_fit_lam=lam,
            plot_density=plot_density,
            plot_density_fit=plot_density_fit,
            plot_density_fit_derivative=plot_density_fit_derivative,
            plot_histogram=plot_histogram,
            histogram_nbins=histogram_nbins,
            show=False,
        )
        for t in peak_times:
            ax.axvline(t, color="k", linestyle="--")
        plt.show()


def real_time(
    adata: AnnData,
    pseudotime_key: str = "pseudotime",
    pseudotime_t_range: tuple[float, float] | None = None,
    periodic: bool | None = None,
    key_added: str = "real_time",
    tmax: float = 100,
    units: Literal["minutes", "hours", "days", "percent"] = "percent",
    bandwidth: float = 1 / 64,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    plot_density: bool = False,
    plot_density_fit: bool = False,
    plot_density_fit_derivative: bool = False,
    plot_histogram: bool = False,
    histogram_nbins: int = 50,
):
    density(
        adata,
        time_key=pseudotime_key,
        t_range=pseudotime_t_range,
        periodic=periodic,
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
        max_grid_size=max_grid_size,
        plot_density=plot_density,
        plot_density_fit=plot_density_fit,
        plot_density_fit_derivative=plot_density_fit_derivative,
        plot_histogram=plot_histogram,
        histogram_nbins=histogram_nbins,
    )

    time_key_uns = adata.uns[pseudotime_key]
    # density function sets appropriate t_range and periodic parameters if missing
    pseudotime_t_range = time_key_uns["t_range"]
    periodic = time_key_uns["periodic"]
    # density_bspline_tck is computed in density function
    density_bspline_tck = time_key_uns["density"]["density_bspline_tck"]

    pt_min, pt_tmax = pseudotime_t_range
    pseudotimes = adata.obs[pseudotime_key].values
    pt_mask = (pt_min <= pseudotimes) * (pseudotimes <= pt_tmax)
    pseudotimes = pseudotimes[pt_mask]

    rt = _area_under_curve(pseudotimes, tmax, density_bspline_tck)

    adata.obs[key_added] = np.nan
    adata.obs.loc[pt_mask, key_added] = rt

    adata.uns[key_added] = {
        "params": {
            "pseudotime_key": pseudotime_key,
            "pseudotime_t_range": pseudotime_t_range,
            "tmax": tmax,
            "units": units,
            "bandwidth": bandwidth,
            "algorithm": algorithm,
            "kernel": kernel,
            "metric": metric,
            "max_grid_size": max_grid_size,
        },
        "density_bspline_tck": density_bspline_tck,
        "tmax": tmax,
        "t_range": [0.0, tmax],
        "t_units": units,
        "periodic": periodic,
    }


def _area_under_curve(
    pseudotimes: NDArray[floating], tmax: float, tck_dict: dict[str, list[float] | int]
):
    bspl = BSpline(**tck_dict)

    # the normalized flux should be 1 / tmax
    q = 1.0 / tmax

    # we will use cumulative_trapezoid to calculate the integral
    # we should make sure that we have enough points to get a good approximation
    # we will use 1000 extra points evenly distributed between 0 and 1 to
    # fill the gaps, and make sure to remove them after the calculation
    n = 1000
    x = np.concatenate([pseudotimes, np.linspace(1 / n, 1, n)])

    # cumulative_trapezoid requires the x values to be sorted
    o = np.argsort(x)
    # we will need to sort the result back to the original order
    oo = np.argsort(o)

    # we need to insert 0 at the beginning, this defines the starting point
    # of the integral
    x = np.insert(x[o], 0, 0)
    d = bspl(x)

    return cumulative_trapezoid(d, x)[oo][:-n] / q


def get_realtimes(
    pseudotimes: NDArray[floating], adata: AnnData, realtime_key: str = "real_time"
):
    tmax = adata.uns[realtime_key]["tmax"]
    tck_dict = adata.uns[realtime_key]["density_bspline_tck"]

    return _area_under_curve(pseudotimes, tmax, tck_dict)


def get_pseudotimes(
    realtimes: NDArray[floating], adata: AnnData, realtime_key: str = "real_time"
):
    tmax = adata.uns[realtime_key]["tmax"]
    tck_dict = adata.uns[realtime_key]["density_bspline_tck"]
    pseudotime_t_range = adata.uns[realtime_key]["params"]["pseudotime_t_range"]

    x = np.linspace(*pseudotime_t_range, 1001)
    y = _area_under_curve(x, tmax, tck_dict)
    interpolator = interp1d(y, x, kind="cubic")

    return interpolator(realtimes)
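Usage note (not part of the diff): a minimal sketch of how the new density helpers could be driven. It assumes the functions are imported directly from their module path inside the wheel (the package may expose a different public entry point) and that adata.obs["pseudotime"] is already populated; key names follow the defaults shown above.

from anndata import AnnData

# assumed import path: the module's location inside the wheel
from sclab.tools.cellflow.density_dynamics._density_dynamics import (
    density,
    density_dynamics,
    real_time,
)


def annotate_pseudotime_density(adata: AnnData) -> None:
    # fit a smoothed 1D density over pseudotime and cache its B-spline
    # (plus t_range/periodic) under adata.uns["pseudotime"]
    density(adata, time_key="pseudotime")
    # locate peaks of the fitted density and store their times, spacings and heights
    density_dynamics(adata, time_key="pseudotime", derivative=0, mode="peaks")
    # rescale pseudotime so cells accumulate uniformly over a 0-100 "percent" axis;
    # this refits the density internally and writes adata.obs["real_time"]
    real_time(adata, pseudotime_key="pseudotime", tmax=100, units="percent")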
File without changes
sclab/tools/cellflow/pseudotime/_pseudotime.py (new file)

@@ -0,0 +1,332 @@
from typing import Literal, NamedTuple

import matplotlib.pyplot as plt
import numpy as np
from anndata import AnnData
from numpy import bool_, floating
from numpy.typing import NDArray
from scipy.integrate import cumulative_trapezoid, quad
from tqdm.auto import tqdm

from ..utils.density_nd import density_nd
from ..utils.interpolate import NDBSpline, NDFourier, fit_smoothing_spline
from .timeseries import periodic_sliding_window

_2PI = 2 * np.pi


class PseudotimeResult(NamedTuple):
    pseudotime: NDArray[floating]
    residues: NDArray[floating]
    phi: NDArray[floating]
    F: NDFourier | NDBSpline
    SNR: NDArray[floating]
    snr_mask: NDArray[bool_]
    t_mask: NDArray[bool_]
    fp_resolution: float


def periodic_parameter(data: NDArray[floating]) -> NDArray[floating]:
    x, y = data.T.astype(float)
    return np.arctan2(y, x) % _2PI


def _pseudotime(
    t: NDArray[floating],
    X: NDArray[floating],
    t_range: tuple[float, float],
    n_dims: int = 10,
    min_snr: float = 0.25,
    periodic: bool = False,
    method: Literal["fourier", "splines"] = "splines",
    largest_harmonic: int = 5,
    roughness: float | None = None,
    progress: bool = True,
) -> PseudotimeResult:
    if not periodic:
        assert method == "splines"

    tmin, tmax = t_range
    tspan = tmax - tmin

    if periodic:
        assert tmin == 0.0

    match method:
        case "fourier":
            F = NDFourier(t_range=t_range, largest_harmonic=largest_harmonic)
        case "splines":
            F = NDBSpline(t_range=t_range, periodic=periodic, roughness=roughness)
        case _:
            raise ValueError(
                f'{method} is not a valid fitting method. Choose one of: "fourier", "splines"'
            )

    t_mask = (tmin <= t) * (t <= tmax)
    t = t[t_mask]
    X = X[t_mask]

    if periodic:
        M = periodic_sliding_window(X, t, 50, np.median)
    else:
        M = X

    # we fit an n-dimensional curve to the data
    F.fit(t, M)

    # we use the signal-to-noise ratio to assess which dimensions show a strong signal
    # we only keep dimensions with some signal through the initial ordering t
    SNR: NDArray[floating] = F(t).var(axis=0) / X.var(axis=0)
    SNR = SNR / SNR.max()
    snr_mask = SNR > min_snr

    dim_mask = np.arange(X.shape[1]) < n_dims

    # we remove noisy dimensions
    X = X[:, snr_mask & dim_mask]
    # `NDFourier` and `NDBSpline` objects can be sliced like so
    full_F = F
    F = F[snr_mask & dim_mask]

    # we will find the closest points in the curve for each data point in X
    # we do this in stages using euclidean distance
    # after each stage we increase the numeric precision
    n = 100
    m = 10
    k = 10

    # T is a matrix of timepoints
    # dim 0 has resolution 0.01
    # dim 1 has resolution 0.0001
    T = (
        np.linspace(tmin, tmax, n + 1)[:-1, None]
        + np.linspace(0, tspan / n, m + 1)[None]
    )
    # evaluate the curve points
    Y = F(T)

    # for each point, we find which row in T has the closest point to the curve
    closest_order_1 = np.argmin(
        np.linalg.norm(
            X[None] - Y[:, [m // 2]],
            axis=2,
        ),
        axis=0,
    )

    # for each point, we find which column in T has the closest point to the curve
    closest_order_2 = np.argmin(
        np.linalg.norm(
            X[:, None] - Y[closest_order_1],
            axis=2,
        ),
        axis=1,
    )

    # we obtain the corresponding pseudotime ordering
    phi = T[closest_order_1, closest_order_2]

    # so far our pseudotime estimation has resolution 0.0001
    # we can refine it to match the floating point resolution of the data's dtype
    fp_res = np.finfo(X.dtype).resolution
    res = 1 / n / m

    n_iters = int(np.floor(np.log10(res) - np.log10(fp_res)))
    range_obj = range(n_iters)
    if progress:
        range_obj = tqdm(range_obj, bar_format="{percentage:3.0f}%|{bar}|")

    for _ in range_obj:
        # we create a new matrix of timepoints
        T = phi[:, None] + np.linspace(-tspan * res / 2, tspan * res / 2, k + 1)

        # make sure we didn't go over the range
        T = T.clip(*t_range)

        # and evaluate the curve points
        Y = F(T)

        # for each point, we find which column in T has the closest point to the curve
        closest_order_3 = np.argmin(
            np.linalg.norm(
                X[:, None] - Y,
                axis=2,
            ),
            axis=1,
        )

        # we obtain the corresponding pseudotime ordering with the current resolution
        phi = T[np.arange(X.shape[0]), closest_order_3]
        # update the current resolution
        res = res / k

    # # converts to unit vectors. returns an array of shape (n_points, n_dims)
    # def unit(v):
    #     return v / np.linalg.norm(v, axis=-1, keepdims=True)

    # # cosine of the angle between the vector from the curve to the data point
    # # and the tangent vector to the curve at the closest point
    # def rv_cosine(p):
    #     R = unit(X - F(p))
    #     V = unit(F(p, d=1))
    #     C = (V * R).sum(axis=-1)
    #     return C

    # cosine_mask = np.abs(rv_cosine(phi)) < 0.01
    # t_mask[t_mask] = cosine_mask
    # phi = phi[cosine_mask]
    # X = X[cosine_mask]

    if periodic:
        phi = phi % tspan

    residues = np.linalg.norm(X - F(phi), axis=-1)

    # speed returns an array of shape (n_points,)
    def speed(t):
        return np.linalg.norm(F(t, d=1), axis=-1)

    # arclen returns a scalar
    def arclen(t):
        return quad(speed, tmin, t, limit=500, epsrel=1.49e-6)[0]

    # we will use cumulative_trapezoid to calculate the integral
    # we should make sure that we have enough points to get a good approximation
    # we will use 1000 extra points evenly distributed between 0 and 1 to
    # fill in the gaps, and make sure to remove them after the calculation
    n = 1_000
    x = np.concatenate([phi, np.linspace(tmin + 1 / n, tmax, n)])

    o = np.argsort(x)
    oo = np.argsort(o)
    x = np.insert(x[o], 0, tmin)
    pseudotime: NDArray[floating] = cumulative_trapezoid(speed(x), x=x)[oo][
        :-n
    ] / arclen(tmax)

    return PseudotimeResult(
        pseudotime,
        residues,
        phi,
        full_F,
        SNR,
        snr_mask,
        t_mask,
        fp_res,
    )


def pseudotime(
    adata: AnnData,
    use_rep: str,
    t_key: str,
    t_range: tuple[float, float],
    n_dims: int = 10,
    min_snr: float = 0.25,
    periodic: bool = False,
    method: Literal["fourier", "splines"] = "splines",
    largest_harmonic: int = 5,
    roughness: float | None = None,
    key_added="pseudotime",
) -> None:
    X = adata.obsm[use_rep].copy().astype(float)
    X_path = np.zeros_like(X)
    X_path_derivative = np.zeros_like(X)
    X_path_derivative_norm = np.zeros((adata.n_obs,))

    t = adata.obs[t_key].values

    result = _pseudotime(
        t, X, t_range, n_dims, min_snr, periodic, method, largest_harmonic, roughness
    )

    t_mask = result.t_mask
    pcs_mask = result.snr_mask
    mask = t_mask[:, None] * pcs_mask

    X_path[mask] = result.F[pcs_mask](result.phi).flatten()
    X_path_derivative[mask] = result.F[pcs_mask](result.phi, d=1).flatten()
    X_path_derivative_norm[t_mask] = np.linalg.norm(X_path_derivative[t_mask], axis=1)

    adata.obs[key_added] = np.nan
    adata.obs[key_added + "_path_residue"] = np.nan

    adata.obs.loc[t_mask, key_added] = result.pseudotime
    adata.obs.loc[t_mask, key_added + "_path_residue"] = result.residues
    # adata.obs[key_added + "_path_derivative_norm"] = X_path_derivative_norm
    adata.obsm[key_added + "_path"] = X_path
    adata.obsm[key_added + "_path_derivative"] = X_path_derivative
    adata.uns[key_added] = {
        "params": {
            "use_rep": use_rep,
            "t_key": t_key,
            "t_range": list(t_range),
            "min_snr": min_snr,
            "periodic": periodic,
            "method": method,
            "largest_harmonic": largest_harmonic,
            "roughness": roughness,
        },
        "snr": result.SNR.tolist(),
        "t_range": [0, 1],
        "periodic": periodic,
    }

    return result


def estimate_periodic_pseudotime_start(
    adata: AnnData,
    time_key: str = "pseudotime",
    bandwidth: float = 1 / 64,
    show_plot: bool = False,
):
    # TODO: Test implementation
    pseudotime = adata.obs[time_key].values.copy()
    t_mask = ~np.isnan(pseudotime)
    for _ in range(2):
        rslt = density_nd(
            pseudotime[t_mask].reshape(-1, 1),
            bandwidth,
            max_grid_size=2**10 + 1,
            periodic=True,
            bounds=((0, 1),),
            normalize=True,
        )
        bspl = fit_smoothing_spline(
            rslt.grid[:, 0],
            1 / rslt.density,
            t_range=(0, 1),
            lam=1e-5,
            periodic=True,
        )
        x = np.linspace(0, 1, 10001)
        y = bspl.derivative(0)(x)
        yp = bspl.derivative(1)(x)
        ypp = bspl.derivative(2)(x)

        if yp[np.argmax(np.abs(yp))] < 0:
            break

        pseudotime = -pseudotime % 1
    else:
        print("Warning: could not check direction for the pseudotime")

    idx = np.argwhere(np.sign(ypp[:-1]) < np.sign(ypp[1:])).flatten()
    roots = (x[idx] + x[1:][idx]) / 2
    heights = yp[idx]

    max_peak_x = roots[heights.argmin()]

    if show_plot:
        plt.hist(
            pseudotime, bins=100, density=True, fill=False, linewidth=0.5, alpha=0.5
        )
        plt.plot(rslt.grid[:-1, 0], rslt.density[:-1], color="k")
        plt.plot(x, y / np.abs(y).max())
        plt.plot(x, yp / np.abs(yp).max())
        plt.axvline(max_peak_x, color="k", linestyle="--")
        plt.show()

    pseudotime = (pseudotime - max_peak_x) % 1
    adata.obs[time_key] = pseudotime