sclab 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sclab might be problematic. Click here for more details.

Files changed (50) hide show
  1. sclab/__init__.py +1 -1
  2. sclab/dataset/_dataset.py +1 -1
  3. sclab/examples/processor_steps/__init__.py +2 -0
  4. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  5. sclab/examples/processor_steps/_integration.py +37 -4
  6. sclab/examples/processor_steps/_neighbors.py +24 -4
  7. sclab/examples/processor_steps/_pca.py +5 -5
  8. sclab/examples/processor_steps/_preprocess.py +14 -1
  9. sclab/examples/processor_steps/_qc.py +22 -6
  10. sclab/gui/__init__.py +0 -0
  11. sclab/gui/components/__init__.py +5 -0
  12. sclab/gui/components/_guided_pseudotime.py +482 -0
  13. sclab/methods/__init__.py +25 -1
  14. sclab/preprocess/__init__.py +18 -0
  15. sclab/preprocess/_cca.py +154 -0
  16. sclab/preprocess/_cca_integrate.py +77 -0
  17. sclab/preprocess/_filter_obs.py +42 -0
  18. sclab/preprocess/_harmony.py +421 -0
  19. sclab/preprocess/_harmony_integrate.py +50 -0
  20. sclab/preprocess/_normalize_weighted.py +61 -0
  21. sclab/preprocess/_subset.py +208 -0
  22. sclab/preprocess/_transfer_metadata.py +137 -0
  23. sclab/preprocess/_transform.py +82 -0
  24. sclab/preprocess/_utils.py +96 -0
  25. sclab/tools/__init__.py +0 -0
  26. sclab/tools/cellflow/__init__.py +0 -0
  27. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  28. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  29. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  30. sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
  31. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  32. sclab/tools/cellflow/utils/__init__.py +0 -0
  33. sclab/tools/cellflow/utils/density_nd.py +136 -0
  34. sclab/tools/cellflow/utils/interpolate.py +334 -0
  35. sclab/tools/cellflow/utils/smoothen.py +124 -0
  36. sclab/tools/cellflow/utils/times.py +55 -0
  37. sclab/tools/differential_expression/__init__.py +5 -0
  38. sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
  39. sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
  40. sclab/tools/doublet_detection/__init__.py +5 -0
  41. sclab/tools/doublet_detection/_scrublet.py +64 -0
  42. sclab/tools/labeling/__init__.py +6 -0
  43. sclab/tools/labeling/sctype.py +233 -0
  44. sclab/utils/__init__.py +5 -0
  45. sclab/utils/_write_excel.py +510 -0
  46. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/METADATA +6 -2
  47. sclab-0.3.0.dist-info/RECORD +81 -0
  48. sclab-0.2.5.dist-info/RECORD +0 -45
  49. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
  50. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,226 @@
1
+ from typing import Callable, NamedTuple
2
+
3
+ import numpy as np
4
+ from numpy.lib.stride_tricks import sliding_window_view
5
+ from numpy.typing import NDArray
6
+ from scipy.signal import find_peaks
7
+ from scipy.sparse import csr_matrix, issparse
8
+ from tqdm.auto import tqdm
9
+
10
+ from ..utils.interpolate import NDBSpline
11
+
12
+
13
def periodic_sliding_window(
    data: NDArray, t: NDArray, window_size: int, fn: Callable[..., NDArray]
) -> NDArray:
    """Apply ``fn`` over a centred sliding window that wraps around in ``t``.

    The samples are ordered by ``t`` and the ordered sequence is treated as
    circular, so windows centred near either end borrow samples from the
    opposite end.

    Parameters
    ----------
    data
        Array whose first axis indexes the samples.
    t
        One coordinate per sample; only its ordering is used.
    window_size
        Requested window width in samples. Even values are increased by one
        so that each window has a central sample.
    fn
        Reduction called as ``fn(windows, axis=-1)``, e.g. ``np.mean``.

    Returns
    -------
    NDArray
        One reduced value per sample, in the original (unsorted) sample order.
    """
    # Force an odd width so that every window is centred on a sample.
    ws = window_size + ((window_size - 1) % 2)
    half = ws // 2

    order = np.argsort(t)
    # argsort of a permutation is its inverse: restores the caller's order.
    inverse = np.argsort(order)

    ordered = np.asarray(data)[order]
    n = ordered.shape[0]

    # Circular padding via modular indexing. Unlike negative slicing this is
    # correct for half == 0 (window_size == 1) and when the window is wider
    # than the data.
    padded = ordered[np.arange(-half, n + half) % n]

    # Shape (n, *data.shape[1:], ws): one window per sample on the last axis.
    windows = sliding_window_view(padded, ws, axis=0)

    # Drop singleton feature axes (never the sample or window axis), which
    # keeps the historical output shape for inputs such as (n, 1).
    singleton_axes = tuple(
        axis for axis in range(1, windows.ndim - 1) if windows.shape[axis] == 1
    )
    if singleton_axes:
        windows = windows.squeeze(axis=singleton_axes)

    return fn(windows, axis=-1)[inverse]
27
+
28
+
29
def equalization(
    times: NDArray,
    t_range: tuple[float, float],
    max_bins: int = 200,
    iterations: int = 10_000,
    tolerance: float = 0.02,
    seed: int | None = None,
) -> NDArray:
    """Iteratively warp ``times`` towards a flat histogram over ``t_range``.

    For each bin count in 25, 50, ..., ``max_bins`` the values are histogrammed
    on slightly jittered bins. Each bin is then stretched by a factor that
    grows with its relative occupancy, and the result is rescaled back to
    ``t_range``. This repeats until the RMS deviation of the relative counts
    drops below ``tolerance`` or ``iterations`` passes have been made.

    Parameters
    ----------
    times
        1-D array of time values.
    t_range
        ``(t_min, t_max)`` interval spanned by the histogram bins.
    max_bins
        Largest number of bins tried.
    iterations
        Maximum number of passes per bin count.
    tolerance
        RMS threshold on the relative bin counts below which a bin count is
        considered converged.
    seed
        Optional seed for the bin-edge jitter; ``None`` keeps the run
        non-deterministic.

    Returns
    -------
    NDArray
        Warped copy of ``times`` in the original order.

    Raises
    ------
    TypeError
        If ``times`` is not a numpy array.
    ValueError
        If ``times`` is not one-dimensional.
    """
    if not isinstance(times, np.ndarray):
        raise TypeError("times must be a numpy array")

    if times.ndim != 1:
        raise ValueError("times must be a 1D array")

    t_min, t_max = t_range
    t_span = t_max - t_min

    # Stretch factor of a bin is 1 + alpha * (relative occupancy).
    alpha = 0.1
    scale_offset = 1

    rng = np.random.default_rng(seed)
    scaled_times = times.copy()
    n_iterations = int(iterations)

    for n_bins in tqdm(np.arange(25, max_bins + 1, 25)):
        for _ in range(n_iterations):
            bins = np.linspace(t_min, t_max, n_bins + 1)
            # Jitter the interior edges by ~1% of a bin width.
            bins[1:-1] += rng.normal(0, t_span / n_bins / 100, bins[1:-1].size)
            counts = np.histogram(scaled_times, bins=bins)[0]
            rel_counts: NDArray = counts / counts.max()
            rms = np.sqrt(np.mean((rel_counts - rel_counts.mean()) ** 2))
            if rms < tolerance:
                break

            scales = rel_counts * alpha + scale_offset

            # Start of every stretched bin on the new axis, and total length.
            cumulative = np.cumsum(np.diff(bins) * scales)
            new_starts = np.concatenate(([0.0], cumulative[:-1]))
            total_length = cumulative[-1]

            # Bin of every sample. Clipping assigns values below t_min to the
            # first bin and values at/above t_max to the last one, so no
            # sample is ever dropped (np.histogram also counts t_max in the
            # last bin).
            bin_idx = np.searchsorted(bins, scaled_times, side="right") - 1
            bin_idx = np.clip(bin_idx, 0, n_bins - 1)

            stretched = (scaled_times - bins[bin_idx]) * scales[bin_idx]
            stretched += new_starts[bin_idx]
            scaled_times = stretched / total_length * t_span + t_min

        else:
            cnts_mean, cnts_max, cnts_min = counts.mean(), counts.max(), counts.min()
            print(
                f"Failed to converge. RMS: {rms}. "
                + f"({cnts_mean=:.2f}, {cnts_max=:.2f}, {cnts_min=:.2f})"
            )

    return scaled_times
92
+
93
+
94
def fit_trends(
    X: NDArray | csr_matrix,
    times: NDArray,
    t_range: tuple[float, float],
    periodic: bool,
    grid_size: int = 128,
    roughness: float | None = None,
    zero_weight: float = 0.5,
    window_width: float | None = None,
    n_timesteps: int | None = None,
    timestep_delta: float | None = None,
    progress: bool = True,
) -> tuple[NDArray, NDArray]:
    """Fit a smooth trend along ``times`` to every column of ``X``.

    Rows whose time is NaN are ignored. The trends are fitted with
    ``NDBSpline`` and evaluated on a regular set of time points.

    Parameters
    ----------
    X
        Dense or sparse matrix with one row per entry of ``times``.
    times
        Time of each row; NaN marks rows to skip.
    t_range
        ``(tmin, tmax)`` interval of the fit and of the evaluation points.
    periodic, grid_size, roughness, zero_weight, window_width
        Forwarded to ``NDBSpline``.
    n_timesteps
        Number of evaluation points spanning ``t_range`` (endpoints included).
    timestep_delta
        Spacing of the evaluation points, as an alternative to
        ``n_timesteps``.
    progress
        Forwarded to ``NDBSpline.fit``.

    Returns
    -------
    x : NDArray
        Evaluation time points. Defaults to 101 points when neither
        ``n_timesteps`` nor ``timestep_delta`` is given.
    Y : NDArray
        Fitted trends evaluated at ``x``, one column per column of ``X``.

    Raises
    ------
    ValueError
        If both ``n_timesteps`` and ``timestep_delta`` are given.
    """
    # Validate first so that a bad call fails before the expensive fit.
    if n_timesteps is not None and timestep_delta is not None:
        raise ValueError("Cannot specify both n_timesteps and timestep_delta")

    if issparse(X):
        X = np.ascontiguousarray(X.toarray())

    tmin, tmax = t_range

    valid = ~np.isnan(times)
    t = times[valid]
    X = X[valid]

    F = NDBSpline(
        grid_size=grid_size,
        t_range=t_range,
        periodic=periodic,
        zero_weight=zero_weight,
        roughness=roughness,
        window_width=window_width,
    )
    F.fit(t, X, progress=progress)

    if n_timesteps is not None:
        x = np.linspace(*t_range, n_timesteps)
    elif timestep_delta is not None:
        x = np.arange(tmin, tmax + timestep_delta, timestep_delta)
    else:
        # default resolution
        x = np.linspace(*t_range, 101)

    return x, F(x)
147
+
148
+
149
class SinglePeakResult(NamedTuple):
    """Per-feature result of ``find_single_peaks``."""

    # Time of the accepted peak per feature (NaN when none was accepted).
    times: NDArray
    # Trend value at the accepted peak per feature.
    heights: NDArray
    # log2 of the peak-height / median ratio per feature.
    scores: NDArray
    # One ``scipy.signal.find_peaks`` properties dict per feature
    # (empty dict when no peak was accepted).
    info: list[dict]
154
+
155
+
156
def find_single_peaks(
    X: NDArray,
    t: NDArray,
    t_range: tuple[float, float] = (0, 1),
    grid_size: int = 512,
    periodic: bool = True,
    zero_weight: float = 0.2,
    roughness: float = 2,
    n_timesteps: int = 201,
    width_range: tuple[float, float] = (0, 100),
    score_threshold: float = 2.5,
    progress: bool = True,
) -> "SinglePeakResult":
    """Find features whose fitted trend has exactly one dominant peak.

    Every column of ``X`` is divided by the 99th percentile of ``X + 1``, a
    trend is fitted with ``fit_trends``, and peaks of the trend are detected
    with ``scipy.signal.find_peaks``. A peak counts as dominant when its
    height divided by the trend's median exceeds ``score_threshold``.
    Features with exactly one dominant peak are reported; all others keep
    NaN values and an empty info dict.

    Parameters
    ----------
    X
        Matrix with one row per entry of ``t`` and one column per feature.
    t
        Time of each row.
    t_range, grid_size, periodic, zero_weight, roughness, n_timesteps
        Forwarded to ``fit_trends``.
    width_range
        ``width`` argument of ``find_peaks``, in evaluation-grid steps.
    score_threshold
        Minimum height / median ratio for a peak to be kept.
    progress
        Show progress bars.

    Returns
    -------
    SinglePeakResult
        ``times``, ``heights`` and ``scores`` (log2 of the ratio) per feature,
        plus the ``find_peaks`` properties dict of each accepted feature. The
        dict describes every peak detected for that feature, not only the
        accepted one.
    """
    X = X / np.percentile(X + 1, 99, axis=0, keepdims=True)
    x, Y = fit_trends(
        X,
        t,
        t_range=t_range,
        periodic=periodic,
        grid_size=grid_size,
        zero_weight=zero_weight,
        roughness=roughness,
        n_timesteps=n_timesteps,
        progress=progress,
    )

    n_features = X.shape[1]
    peak_times = np.full(n_features, np.nan)
    peak_heights = np.full(n_features, np.nan)
    peak_scores = np.full(n_features, np.nan)
    # One independent dict per feature (``[{}] * n`` would alias a single one).
    peak_info_data = [{} for _ in range(n_features)]

    idx_sequence = range(n_features)
    if progress:
        idx_sequence = tqdm(idx_sequence)

    for i in idx_sequence:
        y = Y[:, i]
        peaks, info = find_peaks(y, prominence=0.05, width=width_range, height=0)
        ratios = y[peaks] / np.median(y)
        dominant = ratios > score_threshold
        if np.count_nonzero(dominant) != 1:
            continue

        # Index and score of the one peak that passed the threshold; it is
        # not necessarily the first peak that was detected.
        k = peaks[dominant][0]
        peak_times[i] = x[k]
        peak_heights[i] = y[k]
        peak_scores[i] = np.log2(ratios[dominant][0])
        peak_info_data[i] = info

    return SinglePeakResult(peak_times, peak_heights, peak_scores, peak_info_data)
204
+
205
+
206
def piecewise_scaling(
    times: NDArray,
    t_range: tuple[float, float],
    start: float,
    end: float,
    new_end: float,
) -> NDArray:
    """Linearly remap ``[start, end]`` onto ``[start, new_end]`` within ``t_range``.

    The mapping is piecewise linear and leaves both ends of ``t_range`` fixed:

    * ``[tmin, start)`` is left unchanged,
    * ``[start, end)`` is rescaled onto ``[start, new_end)``,
    * ``[end, tmax]`` is rescaled onto ``[new_end, tmax]``.

    Parameters
    ----------
    times
        Array of time values.
    t_range
        ``(tmin, tmax)`` interval on which the mapping is defined.
    start, end
        Segment of the original axis to be resized.
    new_end
        Position that ``end`` is moved to.

    Returns
    -------
    NDArray
        Remapped times; values outside ``t_range`` (and NaNs) become NaN.
    """
    tmin, tmax = t_range

    times_pws = np.full(times.shape, np.nan)

    head = (times >= tmin) & (times < start)
    times_pws[head] = times[head]

    middle = (times >= start) & (times < end)
    times_pws[middle] = (
        (times[middle] - start) / (end - start) * (new_end - start) + start
    )

    # The right endpoint is included so that times == tmax map to tmax
    # instead of silently turning into NaN.
    tail = (times >= end) & (times <= tmax)
    if tmax > end:
        times_pws[tail] = (
            (times[tail] - end) / (tmax - end) * (tmax - new_end) + new_end
        )
    else:
        # Degenerate tail (end == tmax): avoid 0/0 and continue the middle
        # segment, whose upper limit is new_end.
        times_pws[tail] = new_end

    return times_pws
File without changes
@@ -0,0 +1,136 @@
1
+ from itertools import product
2
+ from typing import Literal, NamedTuple
3
+
4
+ import numpy as np
5
+ from numpy.typing import NDArray
6
+ from scipy.integrate import trapezoid
7
+ from scipy.interpolate import BSpline
8
+ from sklearn.neighbors import KernelDensity
9
+
10
+ from .interpolate import fit_smoothing_spline
11
+
12
+
13
class DensityResult(NamedTuple):
    """Kernel density estimate evaluated on a regular grid (see ``density_nd``)."""

    # Fitted scikit-learn estimator.
    kde: KernelDensity
    # Number of grid points along each dimension.
    grid_size: tuple[int, ...]
    # (lower, upper) bound of the grid for each dimension.
    bounds: tuple[tuple[float, float], ...]
    # Grid coordinates, one row per grid point and one column per dimension.
    grid: NDArray
    # Density at each row of ``grid``, already divided by ``scale``.
    density: NDArray
    # Normalisation constant the raw density was divided by (1 if none).
    scale: float
    # Whether the data were treated as periodic within ``bounds``.
    periodic: bool
21
+
22
+
23
def density_nd(
    data: NDArray,
    bandwidth: float | Literal["scott", "silverman"] | None = None,
    algorithm: Literal["kd_tree", "ball_tree", "auto"] = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    grid_size: tuple | int | None = None,
    max_grid_size: int = 2**5 + 1,
    periodic: bool = False,
    bounds: tuple[tuple[float, float], ...] | None = None,
    normalize: bool = False,
) -> DensityResult:
    """Estimate a kernel density and evaluate it on a regular grid.

    Parameters
    ----------
    data
        Samples, shape ``(nsamples, ndims)``; 1-D input is treated as a
        single dimension.
    bandwidth
        Kernel bandwidth passed to ``KernelDensity``. Defaults to 1/64 of the
        largest extent of ``bounds``.
    algorithm, kernel, metric
        Passed to ``KernelDensity``.
    grid_size
        Number of grid points per dimension (an int applies to every
        dimension). By default each dimension gets a share of
        ``max_grid_size`` proportional to its extent, with at least two
        points.
    max_grid_size
        Number of grid points used for the widest dimension.
    periodic
        Treat the data as periodic within ``bounds`` by fitting on copies of
        the data shifted by one period in every direction.
    bounds
        ``(lower, upper)`` per dimension. Defaults to the data range widened
        by 10% on each side; required when ``periodic`` is true.
    normalize
        Divide the density by its integral over the grid (trapezoid rule in
        1-D, left Riemann sum otherwise).

    Returns
    -------
    DensityResult

    Raises
    ------
    ValueError
        If ``periodic`` is true without ``bounds``, or the number of bounds
        does not match the number of dimensions.
    """
    if data.ndim == 1:
        data = data.reshape(-1, 1)

    nsamples, ndims = data.shape
    if bounds is None:
        if periodic:
            raise ValueError("bounds must be specified if periodic=True")
        lower, upper = data.min(axis=0), data.max(axis=0)
        margins = (upper - lower) / 10
        bounds = tuple(zip(lower - margins, upper + margins))
    if len(bounds) != ndims:
        raise ValueError("must provide bounds for each dimension")

    # Extent of every dimension, shape (ndims,).
    spans = np.diff(bounds).ravel()

    if periodic:
        # Replicate the data one period away in every direction so that the
        # kernels wrap around the boundaries.
        offsets = np.array(list(product([-1, 0, 1], repeat=ndims))) * spans
        dat = (data[None, :, :] + offsets[:, None, :]).reshape(-1, ndims)
    else:
        dat = data

    if bandwidth is None:
        bandwidth = spans.max() / 64

    kde = KernelDensity(
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
    )
    kde.fit(dat)

    if grid_size is None:
        rel_span = spans / spans.max()
        # At least two points per axis: fewer would give an empty grid or a
        # zero step (division by zero) in the normalisation below.
        grid_size = tuple(max(2, int(size)) for size in rel_span * max_grid_size)
    elif isinstance(grid_size, (int, np.integer)):
        grid_size = (int(grid_size),) * ndims

    axes = [np.linspace(*b, n) for b, n in zip(bounds, grid_size)]
    grid = np.meshgrid(*axes, indexing="ij")
    grid = np.vstack([g.ravel() for g in grid]).T
    d = np.exp(kde.score_samples(grid))

    if normalize and ndims == 1:
        scale = trapezoid(d, grid.reshape(-1))
    elif normalize:
        # simple Riemann sum for higher dimensions, using the left corners
        deltas = spans / (np.array(grid_size) - 1)
        left_corners = d.reshape(grid_size)[tuple(slice(0, -1) for _ in grid_size)]
        scale = left_corners.sum() * np.prod(deltas)
    else:
        scale = 1

    d /= scale

    return DensityResult(kde, grid_size, bounds, grid, d, scale, periodic)
94
+
95
+
96
def fit_density_1d(
    times: NDArray[np.floating],
    t_range: tuple[float, float],
    periodic: bool,
    bandwidth: float | None = None,
    algorithm: str = "auto",
    kernel: str = "gaussian",
    metric: str = "euclidean",
    max_grid_size: int = 2**8 + 1,
    lam: float = 1e-5,
) -> tuple[DensityResult, BSpline]:
    """Estimate the 1-D density of ``times`` and smooth it with a spline.

    Values outside ``t_range`` are discarded. The normalised density is
    computed with ``density_nd`` on a grid over ``t_range`` and then fitted
    with ``fit_smoothing_spline``.

    Parameters
    ----------
    times
        Time values.
    t_range
        ``(tmin, tmax)`` interval of the estimate.
    periodic
        Passed to both ``density_nd`` and ``fit_smoothing_spline``.
    bandwidth
        Kernel bandwidth; defaults to 1/64 of the span of ``t_range``.
    algorithm, kernel, metric, max_grid_size
        Passed to ``density_nd``.
    lam
        Smoothing parameter passed to ``fit_smoothing_spline``.

    Returns
    -------
    tuple[DensityResult, BSpline]
        The gridded density and the spline fitted to it.
    """
    lower, upper = t_range

    inside = (lower <= times) & (times <= upper)
    samples = times[inside].reshape(-1, 1)

    if bandwidth is None:
        bandwidth = (upper - lower) / 64

    density = density_nd(
        samples,
        bandwidth=bandwidth,
        algorithm=algorithm,
        kernel=kernel,
        metric=metric,
        max_grid_size=max_grid_size,
        periodic=periodic,
        bounds=(t_range,),
        normalize=True,
    )

    spline = fit_smoothing_spline(
        density.grid[:, 0],
        density.density,
        t_range,
        lam=lam,
        periodic=periodic,
    )

    return density, spline
@@ -0,0 +1,334 @@
1
+ import logging
2
+ from typing import Callable
3
+
4
+ import numpy as np
5
+ from numpy import asarray, ascontiguousarray, floating, prod
6
+ from numpy import empty as np_empty
7
+ from numpy import float64 as np_float64
8
+ from numpy.typing import NDArray
9
+ from scipy.fft import fft, fftfreq
10
+ from scipy.interpolate import BSpline, _fitpack_impl, make_smoothing_spline
11
+ from tqdm.auto import tqdm
12
+
13
+ from .smoothen import choose_grid_size, count_data_in_intervals, smoothen_data
14
+
15
+ try:
16
+ from scipy.interpolate._dierckx import evaluate_spline
17
+ except ImportError:
18
+ from scipy.interpolate._bspl import evaluate_spline
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+ PIX2 = 2 * np.pi
23
+
24
+
25
+ def fit_smoothing_spline(
26
+ x: NDArray[floating],
27
+ y: NDArray[floating],
28
+ t_range: tuple[float, float],
29
+ w: NDArray[floating] | None = None,
30
+ lam: float | None = None,
31
+ periodic: bool = False,
32
+ n_reps: int = 3,
33
+ ) -> BSpline:
34
+ if periodic:
35
+ assert n_reps % 2 == 1
36
+
37
+ o = np.argsort(x)
38
+ x, y = x[o], y[o]
39
+ if w is not None:
40
+ w = w[o]
41
+
42
+ tmin, tmax = t_range
43
+ tspan = tmax - tmin
44
+
45
+ if periodic:
46
+ mask = np.logical_and((x >= tmin), (x < tmax))
47
+ else:
48
+ mask = np.logical_and((x >= tmin), (x <= tmax))
49
+
50
+ x, y = x[mask], y[mask]
51
+ if w is not None:
52
+ w = w[mask]
53
+ n = x.size
54
+
55
+ if periodic:
56
+ xx = np.concatenate([x + i * tspan for i in range(n_reps)])
57
+ yy = np.tile(y, n_reps)
58
+ ww = np.tile(w, n_reps) if w is not None else None
59
+ else:
60
+ xx = x
61
+ yy = y
62
+ ww = w
63
+
64
+ bspl = make_smoothing_spline(xx, yy, ww, lam)
65
+ t, c, k = bspl.tck
66
+
67
+ if periodic:
68
+ N = n_reps // 2
69
+ t = t - tspan * N
70
+ t = t[n * N : -n * N + 1]
71
+ c = c[n * N : -n * N + 1]
72
+
73
+ return BSpline(t, c, k)
74
+
75
+
76
+ class NDFourier:
77
+ def __init__(
78
+ self,
79
+ xh: NDArray[floating] | None = None,
80
+ freq: NDArray[floating] | None = None,
81
+ t_range: tuple[float, float] | None = None,
82
+ grid_size: int | None = None,
83
+ periodic: bool = True,
84
+ largest_harmonic: int = 5,
85
+ d: int = 0,
86
+ zero_weight: float = 1.0,
87
+ smoothing_fn: Callable = np.average,
88
+ ) -> None:
89
+ assert periodic
90
+ assert t_range is not None
91
+ assert t_range[0] == 0
92
+
93
+ self.tmin, self.tmax = self.t_range = t_range
94
+ self.tscale = PIX2 / self.tmax
95
+
96
+ if xh is not None:
97
+ assert freq is not None
98
+ self.n = grid_size + 1
99
+ self.xh = xh.reshape((xh.shape[0], -1, 1)).copy()
100
+ self.freq = freq.reshape((freq.shape[0], -1, 1)).copy()
101
+ self.scaled_freq = 1j * self.freq * self.tscale
102
+
103
+ self.grid_size = grid_size
104
+ self.periodic = periodic
105
+ self.largest_harmonic = largest_harmonic
106
+ self.d = d
107
+ self.zero_weight = zero_weight
108
+ self.smoothing_fn = smoothing_fn
109
+
110
+ def fit(
111
+ self,
112
+ t: NDArray[floating],
113
+ X: NDArray[floating],
114
+ ) -> "NDFourier":
115
+ if self.grid_size is None:
116
+ self.grid_size = choose_grid_size(t, self.t_range)
117
+
118
+ t_grid = np.linspace(*self.t_range, self.grid_size + 1)
119
+ self.X_smooth = smoothen_data(
120
+ t,
121
+ X,
122
+ t_range=self.t_range,
123
+ t_grid=t_grid,
124
+ periodic=self.periodic,
125
+ zero_weight=self.zero_weight,
126
+ fn=self.smoothing_fn,
127
+ )
128
+
129
+ self.n = n = self.X_smooth.shape[0]
130
+ self.X_smooth = self.X_smooth.reshape((n, -1))
131
+
132
+ xh: NDArray[floating] = fft(self.X_smooth, axis=0)
133
+ freq: NDArray[floating] = fftfreq(n, d=1 / n)
134
+
135
+ mask = np.abs(freq) <= self.largest_harmonic
136
+ xh = xh[mask]
137
+ freq = freq[mask]
138
+
139
+ self.xh = xh.reshape((xh.shape[0], -1, 1))
140
+ self.freq = freq.reshape((freq.shape[0], -1, 1))
141
+ self.scaled_freq = 1j * self.freq * self.tscale
142
+
143
+ return self
144
+
145
+ def derivative(self, d=1) -> "NDFourier":
146
+ return NDFourier(
147
+ self.xh,
148
+ self.freq,
149
+ self.t_range,
150
+ self.grid_size,
151
+ self.periodic,
152
+ self.largest_harmonic,
153
+ d + self.d,
154
+ )
155
+
156
+ def __getitem__(self, key) -> "NDFourier":
157
+ return NDFourier(
158
+ self.xh[:, key],
159
+ self.freq,
160
+ self.t_range,
161
+ self.grid_size,
162
+ self.periodic,
163
+ self.largest_harmonic,
164
+ self.d,
165
+ )
166
+
167
+ def __call__(self, x: NDArray[floating], d=0) -> NDArray[floating]:
168
+ x = asarray(x)
169
+ x_shape = x.shape
170
+
171
+ x = ascontiguousarray(x.ravel(), dtype=np_float64)
172
+
173
+ d = d + self.d
174
+ out: NDArray[floating] = np.real(
175
+ (self.xh * self.scaled_freq**d * np.exp(self.scaled_freq * x)).sum(axis=0)
176
+ / self.n
177
+ )
178
+ out = out.T
179
+ out = out.reshape(x_shape + (self.xh.shape[1],))
180
+
181
+ return out
182
+
183
+
184
+ class NDBSpline:
185
+ def __init__(
186
+ self,
187
+ t: NDArray[floating] | None = None,
188
+ C: NDArray[floating] | None = None,
189
+ k: int | None = None,
190
+ t_range: tuple[float, float] | None = None,
191
+ grid_size: int | None = None,
192
+ periodic: bool = False,
193
+ roughness: float | None = None,
194
+ zero_weight: float = 1.0,
195
+ window_width: float | None = None,
196
+ use_grid: bool = True,
197
+ weight_grid: bool = False,
198
+ smoothing_fn: Callable = np.average,
199
+ ) -> None:
200
+ if periodic:
201
+ assert t_range is not None
202
+ assert t_range[0] == 0
203
+
204
+ if t is not None or C is not None or k is not None:
205
+ assert t is not None
206
+ assert C is not None
207
+ assert k is not None
208
+ self.t = t.copy()
209
+ self.C = C.reshape((C.shape[0], -1)).copy()
210
+ self.k = k
211
+
212
+ if t_range is not None:
213
+ self.tmin, self.tmax = self.t_range = t_range
214
+
215
+ self.grid_size = grid_size
216
+ self.periodic = periodic
217
+ self.window_width = window_width
218
+ self.use_grid = use_grid
219
+ self.weight_grid = weight_grid
220
+ self.roughness = roughness
221
+ self.zero_weight = zero_weight
222
+ self.smoothing_fn = smoothing_fn
223
+
224
+ def fit(
225
+ self,
226
+ t: NDArray[floating],
227
+ X: NDArray[floating],
228
+ progress: bool = False,
229
+ ) -> "NDBSpline":
230
+ X = X.reshape((X.shape[0], -1))
231
+ if self.t_range is None:
232
+ self.tmin, self.tmax = self.t_range = t.min(), t.max()
233
+
234
+ if self.grid_size is None:
235
+ self.grid_size = choose_grid_size(t, self.t_range)
236
+
237
+ if self.roughness is None:
238
+ self.roughness = 1
239
+
240
+ if self.use_grid:
241
+ t_grid: NDArray[floating] = np.linspace(*self.t_range, self.grid_size + 1)
242
+ self.lam = 1 / self.grid_size / 10**self.roughness
243
+ else:
244
+ t_grid = None
245
+ self.lam = 1 / 10**self.roughness
246
+ self.X_smooth = smoothen_data(
247
+ t,
248
+ X,
249
+ t_range=self.t_range,
250
+ t_grid=t_grid,
251
+ periodic=self.periodic,
252
+ window_width=self.window_width,
253
+ zero_weight=self.zero_weight,
254
+ progress=progress,
255
+ fn=self.smoothing_fn,
256
+ )
257
+
258
+ if t_grid is not None and self.weight_grid:
259
+ w = np.zeros(self.X_smooth.shape[0], dtype=float)
260
+ n = count_data_in_intervals(t, t_grid) + 1
261
+ if self.periodic:
262
+ n = np.append(n, n[0])
263
+ else:
264
+ n = np.append(n, n[-1])
265
+ w[n > 1] = 1 / np.log1p(n[n > 1])
266
+ else:
267
+ w = None
268
+
269
+ iterator = self.X_smooth.T
270
+ if progress:
271
+ iterator = tqdm(
272
+ iterator,
273
+ bar_format="{desc} {percentage:3.0f}%|{bar}|",
274
+ desc="Fitting bsplines",
275
+ )
276
+
277
+ fit_t_range = (0, 1)
278
+ fit_t_grid = np.linspace(0, 1, self.grid_size + 1)
279
+ fit_t = (t - self.tmin) / (self.tmax - self.tmin)
280
+ C = []
281
+ for x in iterator:
282
+ f = fit_smoothing_spline(
283
+ fit_t_grid if self.use_grid else fit_t,
284
+ x,
285
+ t_range=fit_t_range,
286
+ w=w,
287
+ lam=self.lam,
288
+ periodic=self.periodic,
289
+ )
290
+
291
+ C.append(f.c)
292
+
293
+ self.t = f.t.copy()
294
+ self.t *= self.tmax - self.tmin
295
+ self.t += self.tmin
296
+ self.C = np.array(C).T.copy()
297
+ self.k = 3
298
+
299
+ return self
300
+
301
+ def derivative(self, d: int = 1) -> "NDBSpline":
302
+ # pad the c array if needed
303
+ ct = len(self.t) - len(self.C)
304
+ if ct > 0:
305
+ self.C = np.r_[self.C, np.zeros((ct, self.C.shape[1]))]
306
+ t, C, k = _fitpack_impl.splder((self.t, self.C, self.k), d)
307
+ return NDBSpline(t, C, k, self.t_range, self.grid_size, self.periodic)
308
+
309
+ def __getitem__(self, key) -> "NDBSpline":
310
+ t = self.t
311
+ C = self.C[:, key]
312
+ k = self.k
313
+ return NDBSpline(t, C, k, self.t_range, self.grid_size, self.periodic)
314
+
315
+ def __call__(self, x: NDArray[floating], d: int = 0) -> NDArray[floating]:
316
+ x = asarray(x)
317
+ x_shape = x.shape
318
+
319
+ x = ascontiguousarray(x.ravel(), dtype=np_float64)
320
+ if self.periodic:
321
+ n = self.t.size - self.k - 1
322
+ x = self.t[self.k] + (x - self.t[self.k]) % (self.t[n] - self.t[self.k])
323
+
324
+ out = np_empty((len(x), prod(self.C.shape[1:])), dtype=self.C.dtype)
325
+
326
+ if not self.t.flags.c_contiguous:
327
+ self.t = self.t.copy()
328
+ if not self.C.flags.c_contiguous:
329
+ self.C = self.C.copy()
330
+
331
+ evaluate_spline(self.t, self.C, self.k, x, d, False, out)
332
+ out = out.reshape(x_shape + (self.C.shape[1],))
333
+
334
+ return out