sclab 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sclab might be problematic. Click here for more details.

Files changed (51) hide show
  1. sclab/__init__.py +1 -1
  2. sclab/_sclab.py +10 -3
  3. sclab/dataset/_dataset.py +1 -1
  4. sclab/examples/processor_steps/__init__.py +2 -0
  5. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  6. sclab/examples/processor_steps/_integration.py +37 -4
  7. sclab/examples/processor_steps/_neighbors.py +24 -4
  8. sclab/examples/processor_steps/_pca.py +5 -5
  9. sclab/examples/processor_steps/_preprocess.py +14 -1
  10. sclab/examples/processor_steps/_qc.py +22 -6
  11. sclab/gui/__init__.py +0 -0
  12. sclab/gui/components/__init__.py +5 -0
  13. sclab/gui/components/_guided_pseudotime.py +482 -0
  14. sclab/methods/__init__.py +25 -1
  15. sclab/preprocess/__init__.py +18 -0
  16. sclab/preprocess/_cca.py +154 -0
  17. sclab/preprocess/_cca_integrate.py +77 -0
  18. sclab/preprocess/_filter_obs.py +42 -0
  19. sclab/preprocess/_harmony.py +421 -0
  20. sclab/preprocess/_harmony_integrate.py +50 -0
  21. sclab/preprocess/_normalize_weighted.py +61 -0
  22. sclab/preprocess/_subset.py +208 -0
  23. sclab/preprocess/_transfer_metadata.py +137 -0
  24. sclab/preprocess/_transform.py +82 -0
  25. sclab/preprocess/_utils.py +96 -0
  26. sclab/tools/__init__.py +0 -0
  27. sclab/tools/cellflow/__init__.py +0 -0
  28. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  29. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  30. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  31. sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
  32. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  33. sclab/tools/cellflow/utils/__init__.py +0 -0
  34. sclab/tools/cellflow/utils/density_nd.py +136 -0
  35. sclab/tools/cellflow/utils/interpolate.py +334 -0
  36. sclab/tools/cellflow/utils/smoothen.py +124 -0
  37. sclab/tools/cellflow/utils/times.py +55 -0
  38. sclab/tools/differential_expression/__init__.py +5 -0
  39. sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
  40. sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
  41. sclab/tools/doublet_detection/__init__.py +5 -0
  42. sclab/tools/doublet_detection/_scrublet.py +64 -0
  43. sclab/tools/labeling/__init__.py +6 -0
  44. sclab/tools/labeling/sctype.py +233 -0
  45. sclab/utils/__init__.py +5 -0
  46. sclab/utils/_write_excel.py +510 -0
  47. {sclab-0.2.4.dist-info → sclab-0.3.0.dist-info}/METADATA +7 -2
  48. sclab-0.3.0.dist-info/RECORD +81 -0
  49. sclab-0.2.4.dist-info/RECORD +0 -45
  50. {sclab-0.2.4.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
  51. {sclab-0.2.4.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,349 @@
1
+ import logging
2
+ from typing import Literal
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from anndata import AnnData
7
+ from numpy import floating
8
+ from numpy.typing import NDArray
9
+ from scipy.integrate import cumulative_trapezoid
10
+ from scipy.interpolate import BSpline, interp1d
11
+ from scipy.signal import find_peaks
12
+
13
+ from ..utils.density_nd import fit_density_1d
14
+ from ..utils.times import guess_trange
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def density(
    adata: AnnData,
    time_key: str = "pseudotime",
    t_range: tuple[float, float] | None = None,
    periodic: bool | None = None,
    bandwidth: float = 1 / 64,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    plot_density: bool = False,
    plot_density_fit: bool = False,
    plot_density_fit_derivative: bool = False,
    plot_histogram: bool = False,
    histogram_nbins: int = 50,
):
    """Fit a 1-D density over ``adata.obs[time_key]`` and cache it in ``adata.uns``.

    Parameters
    ----------
    adata
        Annotated data; ``adata.obs[time_key]`` holds the time values.
    time_key
        Column in ``adata.obs`` and entry in ``adata.uns`` to work on.
    t_range
        ``(tmin, tmax)`` of the time axis. When ``None`` the value stored in
        ``adata.uns[time_key]["t_range"]`` is used; if nothing is stored the
        range is guessed from the finite time values.
    periodic
        Whether the time axis wraps around. When ``None`` the value stored in
        ``adata.uns[time_key]["periodic"]`` is used, defaulting to ``False``.
    bandwidth, algorithm, kernel, metric, max_grid_size
        Forwarded unchanged to ``fit_density_1d``.
    plot_density, plot_density_fit, plot_density_fit_derivative, plot_histogram
        If any is true, the fit is plotted through ``plot.density_result_1d``.
    histogram_nbins
        Number of histogram bins forwarded to the plotting helper.

    Returns
    -------
    None
        Results are written to ``adata.uns[time_key]``: the fitted spline is
        stored under ``"density"`` (overwritten on every call), while
        ``"t_range"`` and ``"periodic"`` are only filled in when missing.
    """
    stored = adata.uns[time_key] if time_key in adata.uns else {}

    # An explicit argument always wins; stored metadata is the first fallback
    # and guessing is the last resort.
    if t_range is None:
        if "t_range" in stored:
            t_range = stored["t_range"]
        else:
            pts = adata.obs[time_key].values
            pts = pts[np.isfinite(pts)]
            t_range = guess_trange(pts)
            if pts.size < 500:
                logger.warning(
                    "Guessing t_range may not be accurate for fewer than 500 points."
                    " Consider providing the t_range explicitly instead."
                )

    if periodic is None:
        periodic = stored["periodic"] if "periodic" in stored else False

    times = adata.obs[time_key].values
    # value handed to fit_density_1d as `lam`; it shrinks as the grid grows
    lam = 1 / max_grid_size / 1e4
    rslt, bspl = fit_density_1d(
        times=times,
        t_range=t_range,
        periodic=periodic,
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
        max_grid_size=max_grid_size,
        lam=lam,
    )

    if time_key not in adata.uns:
        adata.uns[time_key] = {}
    entry = adata.uns[time_key]
    # Never overwrite existing metadata, but make sure both keys exist because
    # downstream code (e.g. real_time) reads them back from adata.uns.
    if "t_range" not in entry:
        entry["t_range"] = list(t_range)
    if "periodic" not in entry:
        entry["periodic"] = periodic

    # store the spline as plain lists/ints so it can be rebuilt with BSpline(**tck)
    knots, coeffs, degree = bspl.tck
    density_bspline_tck = dict(t=knots.tolist(), c=coeffs.tolist(), k=degree)
    entry.update(
        {
            "density": {
                "params": {
                    "bandwidth": bandwidth,
                    "algorithm": algorithm,
                    "kernel": kernel,
                    "metric": metric,
                    "max_grid_size": max_grid_size,
                },
                "density_bspline_tck": density_bspline_tck,
            }
        }
    )

    if plot_density or plot_density_fit or plot_density_fit_derivative or plot_histogram:
        # imported lazily so plotting dependencies are only needed on demand
        from ..utils import plot

        plot.density_result_1d(
            rslt,
            data=times[~np.isnan(times)],
            density_fit_lam=lam,
            plot_density=plot_density,
            plot_density_fit=plot_density_fit,
            plot_density_fit_derivative=plot_density_fit_derivative,
            plot_histogram=plot_histogram,
            histogram_nbins=histogram_nbins,
            show=True,
        )
105
+
106
+
107
def density_dynamics(
    adata: AnnData,
    time_key: str = "pseudotime",
    t_range: tuple[float, float] | None = None,
    periodic: bool | None = None,
    bandwidth: float = 1 / 64,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    derivative: int = 0,
    mode: Literal["peaks", "valleys"] = "peaks",
    find_peaks_kwargs: dict | None = None,
    plot_density: bool = False,
    plot_density_fit: bool = False,
    plot_density_fit_derivative: bool = False,
    plot_histogram: bool = False,
    histogram_nbins: int = 50,
):
    """Locate peaks or valleys of the time density (or one of its derivatives).

    Parameters
    ----------
    adata
        Annotated data; ``adata.obs[time_key]`` holds the time values.
    time_key
        Column in ``adata.obs`` and entry in ``adata.uns`` to work on.
    t_range
        ``(tmin, tmax)`` of the time axis. When ``None`` it must be available
        as ``adata.uns[time_key]["t_range"]``.
    periodic
        Whether the time axis wraps around. When ``None`` the stored value is
        used, defaulting to ``False``.
    bandwidth, algorithm, kernel, metric, max_grid_size
        Forwarded unchanged to ``fit_density_1d``.
    derivative
        Order of the spline derivative that is searched for extrema.
    mode
        ``"peaks"`` searches maxima, ``"valleys"`` searches minima (by
        negating the curve).
    find_peaks_kwargs
        Extra keyword arguments for ``scipy.signal.find_peaks``. ``"height"``
        is interpreted as a fraction of the curve maximum (default ``0.0``).
        The mapping is copied and never modified.
    plot_density, plot_density_fit, plot_density_fit_derivative, plot_histogram
        If any is true, the fit is plotted with the detected times marked.
    histogram_nbins
        Number of histogram bins forwarded to the plotting helper.

    Returns
    -------
    None
        Results are stored in
        ``adata.uns[time_key][f"density_dynamics_d{derivative}_{mode}"]``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"peaks"`` nor ``"valleys"``.
    AssertionError
        If ``t_range`` is not given and not stored in ``adata.uns``.
    """
    if mode not in ("peaks", "valleys"):
        raise ValueError(
            f'{mode} is not a valid mode. Choose one of: "peaks", "valleys"'
        )

    if t_range is None:
        test = time_key in adata.uns and "t_range" in adata.uns[time_key]
        assert test, f"t_range must be provided for time_key: {time_key}"
        t_range = adata.uns[time_key]["t_range"]

    if periodic is None:
        if time_key in adata.uns and "periodic" in adata.uns[time_key]:
            periodic = adata.uns[time_key]["periodic"]
        else:
            periodic = False

    times = adata.obs[time_key].values
    lam = 1 / max_grid_size / 1e4
    rslt, bspl = fit_density_1d(
        times=times,
        t_range=t_range,
        periodic=periodic,
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
        max_grid_size=max_grid_size,
        lam=lam,
    )

    # evaluate the requested derivative on a dense grid
    grid = np.linspace(*t_range, 2**16 + 1)
    y = bspl.derivative(derivative)(grid)
    if mode == "valleys":
        # minima of y are maxima of -y
        y = -y

    tmin, tmax = t_range
    tspan = tmax - tmin
    if periodic:
        # Tile three periods (shifted by -tspan, 0, +tspan) so extrema sitting
        # on the wrap-around boundary are still seen as interior peaks.
        tt = np.concatenate([grid[:-1] + i * tspan for i in range(3)]) - tspan
        yy = np.tile(y[:-1], 3)
    else:
        tt = grid
        yy = y

    # Work on a copy: popping from the caller's dict (or from a shared default)
    # would leak state between calls.
    peak_kwargs = dict(find_peaks_kwargs) if find_peaks_kwargs is not None else {}
    # "height" is relative to the curve maximum; convert it to absolute units
    peak_height = peak_kwargs.pop("height", 0.0) * y.max()

    peaks, _ = find_peaks(yy, height=peak_height, **peak_kwargs)
    peak_times = tt[peaks]
    peak_heights = yy[peaks]

    # keep only the peaks of the central period
    peaks_mask = np.logical_and(peak_times >= tmin, peak_times < tmax)
    peak_times = peak_times[peaks_mask]
    peak_heights = peak_heights[peaks_mask]

    timepoints = peak_times - tmin
    if periodic:
        # distance to the previous peak, the first one wrapping to the last
        deltas = (timepoints - np.roll(timepoints, 1)) % tspan
    else:
        # distance to the previous peak, the first one measured from tmin
        timepoints = np.insert(timepoints, 0, 0)
        deltas = timepoints[1:] - timepoints[:-1]

    if time_key not in adata.uns:
        adata.uns[time_key] = {}

    knots, coeffs, degree = bspl.tck
    density_bspline_tck = dict(t=knots.tolist(), c=coeffs.tolist(), k=degree)
    adata.uns[time_key].update(
        {
            f"density_dynamics_d{derivative}_{mode}": {
                "times": peak_times,
                "deltas": deltas,
                "heights": peak_heights,
                "params": {
                    "bandwidth": bandwidth,
                    "algorithm": algorithm,
                    "kernel": kernel,
                    "metric": metric,
                    "max_grid_size": max_grid_size,
                    "find_peaks_kwargs": {"height": peak_height, **peak_kwargs},
                },
                "density_bspline_tck": density_bspline_tck,
            }
        }
    )

    if plot_density or plot_density_fit or plot_density_fit_derivative or plot_histogram:
        # imported lazily so plotting dependencies are only needed on demand
        from ..utils import plot

        ax = plot.density_result_1d(
            rslt,
            data=times[~np.isnan(times)],
            density_fit_lam=lam,
            plot_density=plot_density,
            plot_density_fit=plot_density_fit,
            plot_density_fit_derivative=plot_density_fit_derivative,
            plot_histogram=plot_histogram,
            histogram_nbins=histogram_nbins,
            show=False,
        )
        for peak_time in peak_times:
            ax.axvline(peak_time, color="k", linestyle="--")
        plt.show()
226
+
227
+
228
def real_time(
    adata: AnnData,
    pseudotime_key: str = "pseudotime",
    pseudotime_t_range: tuple[float, float] | None = None,
    periodic: bool | None = None,
    key_added: str = "real_time",
    tmax: float = 100,
    units: Literal["minutes", "hours", "days", "percent"] = "percent",
    bandwidth: float = 1 / 64,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    plot_density: bool = False,
    plot_density_fit: bool = False,
    plot_density_fit_derivative: bool = False,
    plot_histogram: bool = False,
    histogram_nbins: int = 50,
):
    """Convert pseudotime into "real" time using the area under its density.

    The density of ``adata.obs[pseudotime_key]`` is fitted with ``density``;
    each observation is then assigned the integral of that density up to its
    pseudotime, rescaled so that the full integral maps to ``tmax``.

    Parameters
    ----------
    adata
        Annotated data holding the pseudotime column.
    pseudotime_key
        Column in ``adata.obs`` / entry in ``adata.uns`` with the pseudotime.
    pseudotime_t_range, periodic
        Forwarded to ``density``. The values found in
        ``adata.uns[pseudotime_key]`` after that call are the ones used here.
    key_added
        Name of the new ``adata.obs`` column and ``adata.uns`` entry.
    tmax
        Upper end of the real-time axis; the stored range is ``[0.0, tmax]``.
    units
        Label stored next to the result; it does not affect the computation.
    bandwidth, algorithm, kernel, metric, max_grid_size
        Density-fit options, forwarded to ``density``.
    plot_density, plot_density_fit, plot_density_fit_derivative, plot_histogram
    histogram_nbins
        Plot options, forwarded to ``density``.

    Returns
    -------
    None
        ``adata.obs[key_added]`` receives the real times (NaN for observations
        outside the pseudotime range) and ``adata.uns[key_added]`` the
        parameters and the density spline.
    """
    fit_options = {
        "bandwidth": bandwidth,
        "algorithm": algorithm,
        "kernel": kernel,
        "metric": metric,
        "max_grid_size": max_grid_size,
    }
    density(
        adata,
        time_key=pseudotime_key,
        t_range=pseudotime_t_range,
        periodic=periodic,
        plot_density=plot_density,
        plot_density_fit=plot_density_fit,
        plot_density_fit_derivative=plot_density_fit_derivative,
        plot_histogram=plot_histogram,
        histogram_nbins=histogram_nbins,
        **fit_options,
    )

    # read back what `density` left in adata.uns: range, periodicity and spline
    pt_uns = adata.uns[pseudotime_key]
    pseudotime_t_range = pt_uns["t_range"]
    periodic = pt_uns["periodic"]
    density_bspline_tck = pt_uns["density"]["density_bspline_tck"]

    # only observations inside the pseudotime range get a real time
    lower, upper = pseudotime_t_range
    all_pseudotimes = adata.obs[pseudotime_key].values
    in_range = (lower <= all_pseudotimes) * (all_pseudotimes <= upper)
    elapsed = _area_under_curve(
        all_pseudotimes[in_range], tmax, density_bspline_tck
    )

    adata.obs[key_added] = np.nan
    adata.obs.loc[in_range, key_added] = elapsed

    stored_params = {
        "pseudotime_key": pseudotime_key,
        "pseudotime_t_range": pseudotime_t_range,
        "tmax": tmax,
        "units": units,
        **fit_options,
    }
    adata.uns[key_added] = {
        "params": stored_params,
        "density_bspline_tck": density_bspline_tck,
        "tmax": tmax,
        "t_range": [0.0, tmax],
        "t_units": units,
        "periodic": periodic,
    }
299
+
300
+
301
+ def _area_under_curve(
302
+ pseudotimes: NDArray[floating], tmax: float, tck_dict: dict[str, list[float] | int]
303
+ ):
304
+ bspl = BSpline(**tck_dict)
305
+
306
+ # the normalized flux should be 1 / tmax
307
+ q = 1.0 / tmax
308
+
309
+ # we will use cumulative_trapezoid to calculate the integral
310
+ # we should make sure that we have enough points to get a good approximation
311
+ # we will use 1000 extra points evenly distributed between 0 and 1 to
312
+ # fill the gaps, and make sure to remove them after the calculation
313
+ n = 1000
314
+ x = np.concatenate([pseudotimes, np.linspace(1 / n, 1, n)])
315
+
316
+ # cumulative_trapezoid requires the x values to be sorted
317
+ o = np.argsort(x)
318
+ # we will need to sort the result back to the original order
319
+ oo = np.argsort(o)
320
+
321
+ # we need to insert 0 at the beginning, this defines the starting point
322
+ # of the integral
323
+ x = np.insert(x[o], 0, 0)
324
+ d = bspl(x)
325
+
326
+ return cumulative_trapezoid(d, x)[oo][:-n] / q
327
+
328
+
329
def get_realtimes(
    pseudotimes: NDArray[floating], adata: AnnData, realtime_key: str = "real_time"
):
    """Map pseudotimes to real times with the spline stored by ``real_time``.

    Parameters
    ----------
    pseudotimes
        Pseudotime values to convert.
    adata
        Annotated data whose ``adata.uns[realtime_key]`` provides ``"tmax"``
        and ``"density_bspline_tck"``.
    realtime_key
        Entry of ``adata.uns`` to read from.

    Returns
    -------
    numpy.ndarray
        Real time for every input pseudotime.
    """
    stored = adata.uns[realtime_key]
    return _area_under_curve(
        pseudotimes, stored["tmax"], stored["density_bspline_tck"]
    )
336
+
337
+
338
def get_pseudotimes(
    realtimes: NDArray[floating], adata: AnnData, realtime_key: str = "real_time"
):
    """Map real times back to pseudotimes by inverting the stored conversion.

    The pseudotime-to-real-time curve is sampled at 1001 evenly spaced
    pseudotimes and inverted with a cubic interpolator.

    Parameters
    ----------
    realtimes
        Real-time values to convert.
    adata
        Annotated data whose ``adata.uns[realtime_key]`` was written by
        ``real_time``.
    realtime_key
        Entry of ``adata.uns`` to read from.

    Returns
    -------
    numpy.ndarray
        Pseudotime for every input real time.
    """
    stored = adata.uns[realtime_key]
    lower, upper = stored["params"]["pseudotime_t_range"]

    pseudo_grid = np.linspace(lower, upper, 1001)
    real_grid = _area_under_curve(
        pseudo_grid, stored["tmax"], stored["density_bspline_tck"]
    )

    # swap the roles of the axes: real time becomes the abscissa
    inverse = interp1d(real_grid, pseudo_grid, kind="cubic")
    return inverse(realtimes)
File without changes
@@ -0,0 +1,332 @@
1
+ from typing import Literal, NamedTuple
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from anndata import AnnData
6
+ from numpy import bool_, floating
7
+ from numpy.typing import NDArray
8
+ from scipy.integrate import cumulative_trapezoid, quad
9
+ from tqdm.auto import tqdm
10
+
11
+ from ..utils.density_nd import density_nd
12
+ from ..utils.interpolate import NDBSpline, NDFourier, fit_smoothing_spline
13
+ from .timeseries import periodic_sliding_window
14
+
15
+ _2PI = 2 * np.pi
16
+
17
+
18
class PseudotimeResult(NamedTuple):
    """Outputs of ``_pseudotime``, in the order they are returned."""

    # normalized arc-length position along the fitted curve, one value per
    # observation that passed the time-range mask
    pseudotime: NDArray[floating]
    # Euclidean distance between each observation and its closest curve point
    residues: NDArray[floating]
    # curve parameter of the closest point, before arc-length reparametrization
    phi: NDArray[floating]
    # the fitted curve over all dimensions (not restricted to the kept ones)
    F: NDFourier | NDBSpline
    # per-dimension signal-to-noise ratio, scaled so the maximum is 1
    SNR: NDArray[floating]
    # dimensions whose SNR exceeds the `min_snr` threshold
    snr_mask: NDArray[bool_]
    # observations whose initial time lies inside `t_range`
    t_mask: NDArray[bool_]
    # floating point resolution of the data dtype (np.finfo(...).resolution)
    fp_resolution: float
27
+
28
+
29
def periodic_parameter(data: NDArray[floating]) -> NDArray[floating]:
    """Return the polar angle of 2-D points, wrapped into ``[0, 2*pi)``.

    Parameters
    ----------
    data
        Array whose transpose unpacks into exactly two components, ``x`` and
        ``y``.

    Returns
    -------
    numpy.ndarray
        ``arctan2(y, x)`` modulo ``2*pi``.
    """
    x_coord, y_coord = data.T.astype(float)
    angle = np.arctan2(y_coord, x_coord)
    # arctan2 yields values in [-pi, pi]; the modulo shifts negatives up
    return np.mod(angle, _2PI)
32
+
33
+
34
def _pseudotime(
    t: NDArray[floating],
    X: NDArray[floating],
    t_range: tuple[float, float],
    n_dims: int = 10,
    min_snr: float = 0.25,
    periodic: bool = False,
    method: Literal["fourier", "splines"] = "splines",
    largest_harmonic: int = 5,
    roughness: float | None = None,
    progress: bool = True,
) -> PseudotimeResult:
    """Project observations onto a fitted curve and return arc-length pseudotime.

    A curve ``F`` is fitted to ``X`` as a function of the initial ordering
    ``t``. Each observation is then matched to its closest point on the curve
    by a coarse-to-fine grid search, and the matched curve parameter is
    converted to normalized arc length.

    Parameters
    ----------
    t
        Initial ordering, one value per row of ``X``. Rows with ``t`` outside
        ``t_range`` are dropped.
    X
        Data matrix; rows are observations and columns are dimensions.
    t_range
        ``(tmin, tmax)`` of the initial ordering.
    n_dims
        Only the first ``n_dims`` columns of ``X`` are eligible for matching.
    min_snr
        Columns whose normalized signal-to-noise ratio is not above this
        value are dropped before matching.
    periodic
        Whether the ordering wraps around. Requires ``tmin == 0.0``.
    method
        ``"fourier"`` or ``"splines"``; non-periodic fits must use
        ``"splines"``.
    largest_harmonic
        Passed to ``NDFourier``.
    roughness
        Passed to ``NDBSpline``.
    progress
        Show a ``tqdm`` progress bar during refinement.

    Returns
    -------
    PseudotimeResult

    Raises
    ------
    AssertionError
        If ``periodic`` is false and ``method`` is not ``"splines"``, or if
        ``periodic`` is true and ``tmin != 0.0``.
    ValueError
        If ``method`` is not one of the supported names.
    """
    if not periodic:
        assert method == "splines"

    tmin, tmax = t_range
    tspan = tmax - tmin

    if periodic:
        assert tmin == 0.0

    match method:
        case "fourier":
            F = NDFourier(t_range=t_range, largest_harmonic=largest_harmonic)
        case "splines":
            F = NDBSpline(t_range=t_range, periodic=periodic, roughness=roughness)
        case _:
            raise ValueError(
                f'{method} is not a valid fitting method. Choose one of: "fourier", "splines"'
            )

    # keep only observations whose initial time lies inside the range
    t_mask = (tmin <= t) * (t <= tmax)
    t = t[t_mask]
    X = X[t_mask]

    if periodic:
        # presumably a sliding-window median (window 50) along the periodic
        # ordering, used as a denoised fitting target — TODO confirm against
        # periodic_sliding_window
        M = periodic_sliding_window(X, t, 50, np.median)
    else:
        M = X

    # we fit an n-dimensional curve to the data
    F.fit(t, M)

    # we use the signal-to-noise ratio to assess which dimensions show a strong signal
    # we only keep dimensions with some signal through the initial ordering t
    # (SNR = variance of the fitted curve / variance of the data, scaled to max 1)
    SNR: NDArray[floating] = F(t).var(axis=0) / X.var(axis=0)
    SNR = SNR / SNR.max()
    snr_mask = SNR > min_snr

    # restrict matching to the leading n_dims columns
    dim_mask = np.arange(X.shape[1]) < n_dims

    # we remove noisy dimensions
    X = X[:, snr_mask & dim_mask]
    # `NDFourier` and `NDBSpline` objects can be sliced like so
    # (the unsliced curve is kept because it is what gets returned)
    full_F = F
    F = F[snr_mask & dim_mask]

    # we will find the closest points in the curve for each data point in X
    # we do this in stages using euclidean distance
    # after each stage we increase the numeric precision
    n = 100  # rows of the coarse grid
    m = 10  # subdivisions inside each coarse row
    k = 10  # subdivisions per refinement iteration

    # T is a matrix of timepoints with shape (n, m + 1)
    # rows are spaced tspan / n apart (1/100 of the span)
    # columns are spaced tspan / (n * m) apart (1/1000 of the span)
    T = (
        np.linspace(tmin, tmax, n + 1)[:-1, None]
        + np.linspace(0, tspan / n, m + 1)[None]
    )
    # evaluate the curve points
    Y = F(T)

    # for each point, we find which row in T has the closest point to the curve
    # (each row is represented by its middle column, m // 2)
    closest_order_1 = np.argmin(
        np.linalg.norm(
            X[None] - Y[:, [m // 2]],
            axis=2,
        ),
        axis=0,
    )

    # for each point, we find which column in T has the closest point to the curve
    closest_order_2 = np.argmin(
        np.linalg.norm(
            X[:, None] - Y[closest_order_1],
            axis=2,
        ),
        axis=1,
    )

    # we obtain the corresponding pseudotime ordering
    phi = T[closest_order_1, closest_order_2]

    # so far our estimate has a relative resolution of 1 / (n * m) = 0.001
    # we can refine it to match the floating point resolution of the data's dtype
    fp_res = np.finfo(X.dtype).resolution
    res = 1 / n / m

    # each iteration divides `res` by k = 10, so this many iterations are
    # needed to go from `res` down to `fp_res`
    n_iters = int(np.floor(np.log10(res) - np.log10(fp_res)))
    range_obj = range(n_iters)
    if progress:
        range_obj = tqdm(range_obj, bar_format="{percentage:3.0f}%|{bar}|")

    for _ in range_obj:
        # we create a new matrix of timepoints, k + 1 candidates centred on
        # the current estimate of every observation
        T = phi[:, None] + np.linspace(-tspan * res / 2, tspan * res / 2, k + 1)

        # make sure we didn't go over the range
        T = T.clip(*t_range)

        # and evaluate the curve points
        Y = F(T)

        # for each point, we find which column in T has the closest point to the curve
        closest_order_3 = np.argmin(
            np.linalg.norm(
                X[:, None] - Y,
                axis=2,
            ),
            axis=1,
        )

        # we obtain the corresponding pseudotime ordering with the current resolution
        phi = T[np.arange(X.shape[0]), closest_order_3]
        # update the current resolution
        res = res / k

    # NOTE: disabled experiment, kept for reference. It would discard points
    # whose residual vector is not close to perpendicular to the curve tangent.
    # # converts to unit vectors. returns an array of shape (n_points, n_dims)
    # def unit(v):
    #     return v / np.linalg.norm(v, axis=-1, keepdims=True)

    # # cosine of the angle between the vector from the curve to the data point
    # # and the tangent vector to the curve at the closest point
    # def rv_cosine(p):
    #     R = unit(X - F(p))
    #     V = unit(F(p, d=1))
    #     C = (V * R).sum(axis=-1)
    #     return C

    # cosine_mask = np.abs(rv_cosine(phi)) < 0.01
    # t_mask[t_mask] = cosine_mask
    # phi = phi[cosine_mask]
    # X = X[cosine_mask]

    if periodic:
        # tmin is 0 here, so this wraps a clipped tmax back onto the start
        phi = phi % tspan

    # distance from every observation to its matched curve point
    residues = np.linalg.norm(X - F(phi), axis=-1)

    # speed returns an array of shape (n_points,)
    # (norm of the first derivative of the curve)
    def speed(t):
        return np.linalg.norm(F(t, d=1), axis=-1)

    # arclen returns a scalar
    # (curve length from tmin to t, by adaptive quadrature)
    def arclen(t):
        return quad(speed, tmin, t, limit=500, epsrel=1.49e-6)[0]

    # we will use cumulative_trapezoid to calculate the integral
    # we should make sure that we have enough points to get a good approximation
    # we will use 1000 extra points between tmin + 1 / n and tmax to
    # fill in the gaps, and make sure to remove them after the calculation
    # NOTE(review): the 1 / n offset is not scaled by tspan, so the helper
    # grid assumes a span of about 1 — confirm for other ranges
    n = 1_000
    x = np.concatenate([phi, np.linspace(tmin + 1 / n, tmax, n)])

    # integrate over sorted abscissae; `oo` restores the original order
    o = np.argsort(x)
    oo = np.argsort(o)
    # prepending tmin fixes the lower limit of the integral
    x = np.insert(x[o], 0, tmin)
    # cumulative arc length, normalized by the total length of the curve
    pseudotime: NDArray[floating] = cumulative_trapezoid(speed(x), x=x)[oo][
        :-n
    ] / arclen(tmax)

    return PseudotimeResult(
        pseudotime,
        residues,
        phi,
        full_F,
        SNR,
        snr_mask,
        t_mask,
        fp_res,
    )
217
+
218
+
219
def pseudotime(
    adata: AnnData,
    use_rep: str,
    t_key: str,
    t_range: tuple[float, float],
    n_dims: int = 10,
    min_snr: float = 0.25,
    periodic: bool = False,
    method: Literal["fourier", "splines"] = "splines",
    largest_harmonic: int = 5,
    roughness: float | None = None,
    key_added: str = "pseudotime",
) -> PseudotimeResult:
    """Compute pseudotime for ``adata`` and store it next to the fitted path.

    Parameters
    ----------
    adata
        Annotated data to read from and write to.
    use_rep
        Key of ``adata.obsm`` holding the representation the curve is fitted to.
    t_key
        Column of ``adata.obs`` with the initial ordering.
    t_range
        ``(tmin, tmax)`` of the initial ordering; observations outside it are
        left as NaN in the outputs.
    n_dims, min_snr, periodic, method, largest_harmonic, roughness
        Forwarded to ``_pseudotime``.
    key_added
        Base name of everything written to ``adata``.

    Returns
    -------
    PseudotimeResult
        The raw result of ``_pseudotime``.

    Notes
    -----
    Writes ``adata.obs[key_added]``, ``adata.obs[key_added + "_path_residue"]``,
    ``adata.obsm[key_added + "_path"]``,
    ``adata.obsm[key_added + "_path_derivative"]`` and ``adata.uns[key_added]``.
    The stored pseudotime range is ``[0, 1]``.
    """
    X = adata.obsm[use_rep].copy().astype(float)
    t = adata.obs[t_key].values

    result = _pseudotime(
        t, X, t_range, n_dims, min_snr, periodic, method, largest_harmonic, roughness
    )

    t_mask = result.t_mask
    pcs_mask = result.snr_mask
    # cells that are both an in-range observation and an SNR-passing dimension
    mask = t_mask[:, None] * pcs_mask

    # Path and derivative are only defined on the masked cells; everything
    # else stays zero. Boolean assignment fills in row-major order, which
    # matches the flattened (observation, dimension) layout of the curve output.
    curve = result.F[pcs_mask]
    X_path = np.zeros_like(X)
    X_path_derivative = np.zeros_like(X)
    X_path[mask] = curve(result.phi).flatten()
    X_path_derivative[mask] = curve(result.phi, d=1).flatten()

    adata.obs[key_added] = np.nan
    adata.obs[key_added + "_path_residue"] = np.nan

    adata.obs.loc[t_mask, key_added] = result.pseudotime
    adata.obs.loc[t_mask, key_added + "_path_residue"] = result.residues
    adata.obsm[key_added + "_path"] = X_path
    adata.obsm[key_added + "_path_derivative"] = X_path_derivative
    adata.uns[key_added] = {
        "params": {
            "use_rep": use_rep,
            "t_key": t_key,
            "t_range": list(t_range),
            # recorded so the run can be reproduced from the stored params
            "n_dims": n_dims,
            "min_snr": min_snr,
            "periodic": periodic,
            "method": method,
            "largest_harmonic": largest_harmonic,
            "roughness": roughness,
        },
        "snr": result.SNR.tolist(),
        "t_range": [0, 1],
        "periodic": periodic,
    }

    return result
276
+
277
+
278
def estimate_periodic_pseudotime_start(
    adata: AnnData,
    time_key: str = "pseudotime",
    bandwidth: float = 1 / 64,
    show_plot: bool = False,
):
    """Re-orient and shift a periodic pseudotime so it starts at a density edge.

    The density of the pseudotime on the periodic unit interval is estimated
    and a periodic smoothing spline is fitted to its reciprocal. The direction
    is kept when the steepest slope of that spline is negative; otherwise the
    pseudotime is reversed and the check repeated once. The new origin is the
    upward zero crossing of the second derivative (an inflection point) with
    the most negative slope; the pseudotime is shifted so that this point
    becomes 0.

    Parameters
    ----------
    adata
        Annotated data; ``adata.obs[time_key]`` is read and overwritten.
    time_key
        Column of ``adata.obs`` with a pseudotime in ``[0, 1]``.
    bandwidth
        Bandwidth forwarded to ``density_nd``.
    show_plot
        Plot the histogram, density, spline and chosen origin.

    Returns
    -------
    None
        ``adata.obs[time_key]`` is replaced in place.

    Raises
    ------
    ValueError
        If the spline has no suitable inflection point to use as origin.
    """
    import warnings

    # TODO: Test implementation
    pseudotime = adata.obs[time_key].values.copy()
    # observations without pseudotime are excluded from the fit but kept (as
    # NaN) in the output
    t_mask = ~np.isnan(pseudotime)
    for _ in range(2):
        rslt = density_nd(
            pseudotime[t_mask].reshape(-1, 1),
            bandwidth,
            max_grid_size=2**10 + 1,
            periodic=True,
            bounds=((0, 1),),
            normalize=True,
        )
        bspl = fit_smoothing_spline(
            rslt.grid[:, 0],
            1 / rslt.density,
            t_range=(0, 1),
            lam=1e-5,
            periodic=True,
        )
        x = np.linspace(0, 1, 10001)
        y = bspl.derivative(0)(x)
        yp = bspl.derivative(1)(x)
        ypp = bspl.derivative(2)(x)

        # accept the direction when the steepest slope is a descent
        if yp[np.argmax(np.abs(yp))] < 0:
            break

        # reverse the direction on the unit circle and try again
        pseudotime = -pseudotime % 1
    else:
        # NOTE(review): at this point the curves come from the once-reversed
        # ordering while `pseudotime` has been reversed twice — confirm which
        # orientation is intended when the check fails.
        warnings.warn(
            "could not check direction for the pseudotime", stacklevel=2
        )

    # upward sign changes of the second derivative, i.e. local minima of yp
    idx = np.argwhere(np.sign(ypp[:-1]) < np.sign(ypp[1:])).flatten()
    if idx.size == 0:
        raise ValueError(
            "could not find an inflection point to use as pseudotime start"
        )
    # midpoint of the two grid nodes that bracket each crossing
    roots = (x[idx] + x[1:][idx]) / 2
    heights = yp[idx]

    # the crossing with the most negative slope becomes the new origin
    max_peak_x = roots[heights.argmin()]

    if show_plot:
        # histogramming fails on NaN, so only valid pseudotimes are shown
        plt.hist(
            pseudotime[t_mask],
            bins=100,
            density=True,
            fill=False,
            linewidth=0.5,
            alpha=0.5,
        )
        plt.plot(rslt.grid[:-1, 0], rslt.density[:-1], color="k")
        plt.plot(x, y / np.abs(y).max())
        plt.plot(x, yp / np.abs(yp).max())
        plt.axvline(max_peak_x, color="k", linestyle="--")
        plt.show()

    pseudotime = (pseudotime - max_peak_x) % 1
    adata.obs[time_key] = pseudotime