sclab 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sclab might be problematic. Click here for more details.
- sclab/__init__.py +1 -1
- sclab/_sclab.py +7 -3
- sclab/dataset/_dataset.py +1 -1
- sclab/dataset/processor/_processor.py +19 -4
- sclab/examples/processor_steps/__init__.py +2 -0
- sclab/examples/processor_steps/_doublet_detection.py +68 -0
- sclab/examples/processor_steps/_integration.py +47 -20
- sclab/examples/processor_steps/_neighbors.py +24 -4
- sclab/examples/processor_steps/_pca.py +11 -6
- sclab/examples/processor_steps/_preprocess.py +14 -1
- sclab/examples/processor_steps/_qc.py +22 -6
- sclab/gui/__init__.py +0 -0
- sclab/gui/components/__init__.py +7 -0
- sclab/gui/components/_guided_pseudotime.py +482 -0
- sclab/gui/components/_transfer_metadata.py +186 -0
- sclab/methods/__init__.py +16 -0
- sclab/preprocess/__init__.py +19 -0
- sclab/preprocess/_cca.py +154 -0
- sclab/preprocess/_cca_integrate.py +109 -0
- sclab/preprocess/_filter_obs.py +42 -0
- sclab/preprocess/_harmony.py +421 -0
- sclab/preprocess/_harmony_integrate.py +53 -0
- sclab/preprocess/_normalize_weighted.py +61 -0
- sclab/preprocess/_subset.py +208 -0
- sclab/preprocess/_transfer_metadata.py +137 -0
- sclab/preprocess/_transform.py +82 -0
- sclab/preprocess/_utils.py +96 -0
- sclab/tools/__init__.py +0 -0
- sclab/tools/cellflow/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
- sclab/tools/cellflow/pseudotime/__init__.py +0 -0
- sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
- sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
- sclab/tools/cellflow/utils/__init__.py +0 -0
- sclab/tools/cellflow/utils/density_nd.py +215 -0
- sclab/tools/cellflow/utils/interpolate.py +334 -0
- sclab/tools/cellflow/utils/smoothen.py +124 -0
- sclab/tools/cellflow/utils/times.py +55 -0
- sclab/tools/differential_expression/__init__.py +5 -0
- sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
- sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
- sclab/tools/doublet_detection/__init__.py +5 -0
- sclab/tools/doublet_detection/_scrublet.py +64 -0
- sclab/tools/labeling/__init__.py +6 -0
- sclab/tools/labeling/sctype.py +233 -0
- sclab/utils/__init__.py +5 -0
- sclab/utils/_write_excel.py +510 -0
- {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/METADATA +6 -2
- sclab-0.3.1.dist-info/RECORD +82 -0
- sclab-0.2.5.dist-info/RECORD +0 -45
- {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/WHEEL +0 -0
- {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import plotly.express as px
|
|
4
|
+
import plotly.graph_objects as go
|
|
5
|
+
from ipywidgets import Button, Dropdown, FloatLogSlider, FloatSlider, HBox, Text
|
|
6
|
+
from numpy import floating
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
from plotly.subplots import make_subplots
|
|
9
|
+
from sclab.dataset.processor import Processor
|
|
10
|
+
from sclab.dataset.processor.step import ProcessorStepBase
|
|
11
|
+
|
|
12
|
+
# TODO: remove self.drawn_path and self.drawn_path_residue from the class
|
|
13
|
+
# and add them to the dataset as _drawn_path and _drawn_path_residue
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# One full turn in radians; used below to rescale the angle returned by
# ``periodic_parameter`` onto the unit interval.
_2PI = np.pi * 2
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class GuidedPseudotime(ProcessorStepBase):
    """Interactive step that fits a pseudotime curve along a user-drawn path.

    Workflow, as wired up in ``__init__`` / ``make_controls``:

    1. A path is drawn on the figure, either by hand (*Start Drawing*) or
       automatically (*Automatic Drawing*, a closed curve fitted with
       ``NDFourier``). Every data point is projected onto the path, which gives
       a ``drawn_path`` coordinate (clipped to ``[0, 1]``) and a residue, i.e.
       the orthogonal distance to the path.
    2. Points far from the path are discarded: by the ``residue_threshold``
       slider and by an automatic jump-based outlier filter.
    3. *Compute Pseudotime* calls ``pseudotime(method="splines")`` on the
       selected representation (``use_rep``) and stores the result under
       ``key_added`` (plus ``obsm[f"{key_added}_path"]``).
    """

    parent: Processor
    name: str = "guided_pseudotime"
    description: str = "Guided Pseudotime"

    run_button_description = "Compute Pseudotime"

    def __init__(self, parent: Processor) -> None:
        """Build the controls and wire the widget callbacks.

        Args:
            parent: Owning processor. Its ``dataset.adata.obsm`` keys populate
                the representation dropdown.
        """
        variable_controls = dict(
            residue_threshold=FloatLogSlider(
                value=1,
                min=-3,
                max=0,
                description="Filter by Dist.",
                continuous_update=True,
            ),
            use_rep=Dropdown(
                options=tuple(parent.dataset.adata.obsm.keys()),
                value=None,
                description="Use rep.",
            ),
            roughness=FloatSlider(
                value=0,
                min=0,
                max=3,
                step=0.05,
                description="Roughness",
                continuous_update=False,
            ),
            min_snr=FloatSlider(
                value=0.25,
                min=0,
                max=1,
                step=0.05,
                description="SNR",
                continuous_update=False,
            ),
            key_added=Text(
                value="pseudotime",
                description="Key added",
                placeholder="",
            ),
        )

        super().__init__(
            parent=parent,
            fixed_params={},
            variable_controls=variable_controls,
        )
        # Nothing can be computed until a representation is selected.
        self.run_button.disabled = True
        self.run_button.button_style = ""
        # Only meaningful for closed (periodic) paths; revealed after a run.
        self.estimate_start_time_button.layout.visibility = "hidden"
        self.estimate_start_time_button.layout.height = "0px"

        self.start_drawing_button.on_click(self.toggle_drawing_callback)
        self.plot_signal.on_click(self.send_signal_plot)
        self.estimate_start_time_button.on_click(
            self.estimate_periodic_pseudotime_start
        )
        self.plot_fitted_pseudotime_curve.on_click(self.send_fitted_pseudotime_plot)
        self.variable_controls["use_rep"].observe(
            self.update_buttons_state, names="value", type="change"
        )
        self.variable_controls["residue_threshold"].observe(
            self._assign_drawn_path_values, names="value", type="change"
        )

        # Any change of fit parameters invalidates an open signal plot.
        self.variable_controls["use_rep"].observe(
            self.close_signal_plot, names="value", type="change"
        )
        self.variable_controls["roughness"].observe(
            self.close_signal_plot, names="value", type="change"
        )
        self.variable_controls["min_snr"].observe(
            self.close_signal_plot, names="value", type="change"
        )

    def make_controls(self):
        """Create the step-specific buttons and lay out all controls."""
        self.start_drawing_button = Button(
            description="Start Drawing", button_style="primary"
        )

        self.auto_drawing_button = Button(
            description="Automatic Drawing", button_style="primary"
        )
        self.auto_drawing_button.on_click(self._automatic_periodic_path_drawing)

        self.plot_signal = Button(
            description="Plot Signal", button_style="info", disabled=True
        )

        self.estimate_start_time_button = Button(
            description="Estimate Pseudotime Start",
            button_style="",
            disabled=True,
        )

        self.plot_fitted_pseudotime_curve = Button(
            description="Plot Fitted Pseudotime Curve",
            button_style="info",
            disabled=True,
        )

        self.controls_list = [
            HBox([self.auto_drawing_button, self.start_drawing_button]),
            *self.variable_controls.values(),
            self.plot_signal,
            self.run_button,
            self.plot_fitted_pseudotime_curve,
            self.estimate_start_time_button,
            self.output,
        ]
        super().make_controls()

    def update_buttons_state(self, *args, **kwargs):
        """Enable the action buttons only when a representation is selected and
        no drawing is in progress; otherwise disable (and hide) them."""
        drawing = self.start_drawing_button.description != "Start Drawing"
        if self.variable_controls["use_rep"].value is None or drawing:
            self.run_button.disabled = True
            self.run_button.button_style = ""

            self.plot_signal.disabled = True
            self.plot_signal.button_style = ""

            self.plot_fitted_pseudotime_curve.disabled = True
            self.plot_fitted_pseudotime_curve.button_style = ""

            self.estimate_start_time_button.disabled = True
            self.estimate_start_time_button.button_style = ""
            self.estimate_start_time_button.layout.visibility = "hidden"
            self.estimate_start_time_button.layout.height = "0px"

            return

        self.run_button.disabled = False
        self.run_button.button_style = "primary"

        self.plot_signal.disabled = False
        self.plot_signal.button_style = "info"

        self.plot_fitted_pseudotime_curve.disabled = False
        self.plot_fitted_pseudotime_curve.button_style = "info"

    def toggle_drawing_callback(self, _: Button | None = None):
        """Toggle free-hand drawing mode on the plotter via the broker."""
        if self.start_drawing_button.description == "Start Drawing":
            self.start_drawing_button.disabled = False
            self.start_drawing_button.button_style = "warning"
            self.start_drawing_button.description = "--> click here when ready <--"

            self.update_buttons_state()

            self.broker.publish("dplt_start_drawing_request")
            self.update_output("Use your mouse pointer to draw a path on the figure")
        else:
            self.start_drawing_button.disabled = False
            self.start_drawing_button.button_style = "primary"
            self.start_drawing_button.description = "Start Drawing"

            self.update_buttons_state()

            self.broker.publish("dplt_end_drawing_request")
            self.update_output(
                "Click on the **Run** button to fit a pseudotime curve to the data"
                + " points. Make sure to select a data representation before"
                + " running the analysis."
            )

    def estimate_periodic_pseudotime_start(self, _: Button | None = None):
        """Re-anchor a periodic pseudotime stored under ``key_added`` and notify
        listeners that the metadata column changed."""
        from ...tools.cellflow.pseudotime._pseudotime import estimate_periodic_pseudotime_start

        time_key = self.variable_controls["key_added"].value
        estimate_periodic_pseudotime_start(self.parent.dataset.adata, time_key=time_key)
        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, time_key
        )
        self.estimate_start_time_button.button_style = "success"

    def send_signal_plot(self, _: Button | None = None):
        """Toggle a per-dimension plot of ``use_rep`` against ``drawn_path``.

        For each of the first 16 dimensions, the data are shown together with
        an ``NDBSpline`` fit. Each subplot is titled with the dimension's
        normalized SNR (variance of the fit over variance of the data, divided
        by the maximum). Dimensions below ``min_snr`` are greyed out.
        """
        from ...tools.cellflow.utils.interpolate import NDBSpline

        if self.plot_signal.description == "Plot Signal":
            adata = self.parent.dataset.adata

            # The button is enabled as soon as a representation is chosen, so
            # it can be clicked before any path exists.
            if "drawn_path" not in adata.uns or "drawn_path" not in adata.obs:
                self.update_output(
                    "Draw a path on the figure before plotting the signal."
                )
                return

            use_rep = self.variable_controls["use_rep"].value
            roughness = self.variable_controls["roughness"].value
            min_snr = self.variable_controls["min_snr"].value
            periodic = adata.uns["drawn_path"]["path_is_closed"]

            t_range = (0.0, 1.0)
            tmin, tmax = t_range

            t = adata.obs["drawn_path"].values
            X = adata.obsm[use_rep]

            df = pd.DataFrame(
                X,
                columns=[f"Dim {i + 1}" for i in range(X.shape[1])],
                index=adata.obs_names,
            )
            df = df.join(self.parent.dataset.metadata)
            df["index"] = df.index

            # NaN path values (filtered-out cells) compare False and drop out.
            t_mask = (tmin <= t) * (t <= tmax)
            t = t[t_mask]
            X = X[t_mask]
            df = df.loc[t_mask]

            max_dims = 16
            ndims = min(X.shape[1], max_dims)

            F = NDBSpline(t_range=t_range, periodic=periodic, roughness=roughness)
            F.fit(t, X)

            SNR: NDArray[floating] = F(t).var(axis=0) / X.var(axis=0)
            SNR = SNR / SNR.max()

            x = np.linspace(*t_range, 200)
            Y = F(x)

            rows = cols = int(np.ceil(np.sqrt(ndims)))
            titles = [f"Dim {i + 1}. SNR: {SNR[i]:.2f}" for i in range(ndims)]
            fig = make_subplots(
                rows=rows,
                cols=cols,
                shared_xaxes=True,
                shared_yaxes=False,
                x_title="Drawn path",
                y_title="Signal",
                subplot_titles=titles,
            )

            for i in range(ndims):
                row = i // cols + 1
                col = i % cols + 1
                snr = SNR[i]
                marker_color = "blue" if snr >= min_snr else "lightgray"
                line_color = "red" if snr >= min_snr else "gray"

                scatter = px.scatter(
                    df,
                    x="drawn_path",
                    y=f"Dim {i + 1}",
                    template="simple_white",
                    hover_name="index",
                )
                scatter.update_traces(marker=dict(size=5, color=marker_color))

                for trace in scatter.data:
                    fig.add_trace(trace, row=row, col=col)

                line = go.Scattergl(
                    x=x,
                    y=Y[:, i],
                    mode="lines",
                    line_color=line_color,
                )
                fig.add_trace(line, row=row, col=col)

            fig.update_layout(showlegend=False, title=f"{use_rep} Signal Plot")
            self.plot_signal.description = "Close Signal Plot"
            self.plot_signal.button_style = "warning"

        else:
            fig = None
            self.plot_signal.description = "Plot Signal"
            self.plot_signal.button_style = "info"

        # figure=None asks the plotter to close the auxiliary figure.
        self.broker.publish("dplt_plot_figure_request", figure=fig)

    def close_signal_plot(self, *args, **kwargs):
        """Reset the signal-plot button and close the auxiliary figure."""
        self.plot_signal.description = "Plot Signal"
        self.plot_signal.button_style = "info"
        self.broker.publish("dplt_plot_figure_request", figure=None)

    def function(
        self,
        use_rep: str,
        roughness: float,
        min_snr: float,
        key_added: str,
        **kwargs,
    ):
        """Fit the pseudotime along ``drawn_path`` and publish the result.

        Args:
            use_rep: ``adata.obsm`` key of the representation to fit.
            roughness: Spline roughness passed to ``pseudotime``.
            min_snr: Minimum SNR passed to ``pseudotime``.
            key_added: Name under which ``pseudotime`` stores its output.
        """
        from ...tools.cellflow.pseudotime._pseudotime import pseudotime

        self.plot_signal.description = "Plot Signal"
        self.plot_signal.button_style = "info"

        periodic = self.parent.dataset.adata.uns["drawn_path"]["path_is_closed"]

        self.output.clear_output(wait=True)
        with self.output:
            pseudotime(
                adata=self.parent.dataset.adata,
                use_rep=use_rep,
                t_key="drawn_path",
                t_range=(0.0, 1.0),
                min_snr=min_snr,
                periodic=periodic,
                method="splines",
                roughness=roughness,
                key_added=key_added,
            )

        self.parent.dataset.clear_selected_rows()
        self.broker.publish("ctrl_data_key_value_change_request", use_rep)
        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, key_added
        )

        self.send_fitted_pseudotime_plot()
        self.update_output("")

        if periodic:
            self.estimate_start_time_button.disabled = False
            self.estimate_start_time_button.button_style = "primary"
            self.estimate_start_time_button.layout.visibility = "visible"
            self.estimate_start_time_button.layout.height = "28px"

    def send_fitted_pseudotime_plot(self, *args, **kwargs):
        """Overlay the fitted curve ``obsm[f"{key_added}_path"]`` as a red line,
        ordered by pseudotime, on the ``use_rep`` plot."""
        use_rep = self.variable_controls["use_rep"].value
        key_added = self.variable_controls["key_added"].value
        adata = self.parent.dataset.adata

        # The button is enabled before any fit has been run.
        if key_added not in adata.obs or f"{key_added}_path" not in adata.obsm:
            self.update_output(
                f"No fitted pseudotime found under '{key_added}'."
                + " Run the analysis first."
            )
            return

        t: NDArray = adata.obs[key_added].values
        t_mask = ~np.isnan(t)
        t = t[t_mask]

        X_path = adata.obsm[f"{key_added}_path"]
        template = self.parent.dataset.data
        # Build a fresh frame instead of writing through ``.values[:]``: that
        # write is lost when ``.values`` is a copy (mixed dtypes) and fails
        # under pandas Copy-on-Write, where ``.values`` is read-only.
        data = pd.DataFrame(
            np.asarray(X_path), index=template.index, columns=template.columns
        )
        data = data.loc[t_mask]
        data = data.iloc[t.argsort()]
        self.broker.publish(
            "dplt_add_data_as_line_trace_request",
            use_rep,
            data,
            name=key_added,
            line_color="red",
        )

    def _assign_drawn_path_values(self, *args, **kwargs):
        """Filter the cached path projection and store it as ``drawn_path``.

        Points whose normalized residue exceeds ``residue_threshold`` are set
        to NaN. If the sorted residues show a jump larger than 25% of the
        maximum, every point above that jump is dropped as an outlier too.
        The remaining points become the dataset selection.
        """
        # The residue slider can fire before any path has been drawn; there is
        # nothing to filter yet (``drawn_path`` is set by _compute_drawn_path).
        if getattr(self, "drawn_path", None) is None:
            return

        dataset = self.parent.dataset

        drawn_path = self.drawn_path.copy()
        drawn_path_residue = self.drawn_path_residue.copy()

        residue_threshold = self.variable_controls["residue_threshold"].value
        x = drawn_path_residue / drawn_path_residue.max()
        drawn_path.loc[x > residue_threshold] = np.nan

        # detecting outliers: points with projection to the curve, but that are not
        # part of the cluster of points where the user drew the path
        x = drawn_path_residue.loc[drawn_path.notna()]
        x = x[~np.isnan(x)].values
        # Need at least two residues to look for a jump; an empty array would
        # also make ``x.max()`` raise.
        if x.size > 1:
            # sort in descending order
            x = np.sort(x)[::-1]
            # normalize
            y = x / x.max()
            # detect jumps in the normalized values
            d = y[:-1] - y[1:]
            if (d > 0.25).any():
                # if there is a spike in the residue values, and the spike is larger
                # than 25% of the maximum residue value, then remove the points with
                # residue values larger than the threshold
                thr = x[:-1][d > 0.25].min()
                drawn_path.loc[drawn_path_residue >= thr] = np.nan

        selected: NDArray = drawn_path.notna().values
        if selected.any():
            dataset.selected_rows = dataset.row_names[selected]
        else:
            dataset.selected_rows = None

        publish_change = "drawn_path" not in dataset._metadata
        dataset._metadata["drawn_path"] = drawn_path.clip(0, 1)
        if publish_change:
            self.broker.publish("dset_metadata_change", dataset.metadata)

    def _automatic_periodic_path_drawing(self, *args, **kwargs):
        """Fit a closed path to the first two plotted coordinates and use it as
        the drawn path."""
        from ...tools.cellflow.pseudotime._pseudotime import periodic_parameter
        from ...tools.cellflow.utils.interpolate import NDFourier

        data_points_array = self.parent.plotter.data_for_plot.values[:, :2]
        ordr_points_array = periodic_parameter(data_points_array) / _2PI
        F = NDFourier(t_range=(0, 1), grid_size=128, smoothing_fn=np.median)
        F.fit(ordr_points_array, data_points_array)

        T = time_points = np.linspace(0, 1, 1024 + 1)
        path_points_array = F(time_points)

        # Complex encoding (x + iy) is what _compute_drawn_path expects.
        X = data_points_array[:, 0] + 1j * data_points_array[:, 1]
        P = path_points_array[:, 0] + 1j * path_points_array[:, 1]
        self._compute_drawn_path(T, X, P, path_is_closed=True)
        self._assign_drawn_path_values()
        line = go.Scattergl(x=P.real, y=P.imag, mode="lines", line_color="black")
        self.broker.publish("dplt_add_trace_request", trace=line)

    def _compute_drawn_path(
        self,
        time_points: NDArray[floating],
        data_points: NDArray[floating],
        path_points: NDArray[floating],
        path_is_closed: bool,
    ):
        """Project every data point orthogonally onto the drawn polyline.

        Data and path points are complex numbers (x + iy). For each data point,
        the three nearest path vertices are taken in path order (P1, P2, P3).
        The point is projected onto segment P1->P2 or, if the foot of the
        projection falls outside it, onto P1->P3. Points that project inside
        neither segment stay NaN.

        The interpolated path time and the orthogonal distance are cached on
        ``self.drawn_path`` / ``self.drawn_path_residue``. The path parameters
        are recorded in ``adata.uns["drawn_path"]``.
        """
        dataset = self.parent.dataset

        T = time_points
        X = data_points
        P = path_points

        # three nearest path vertices per data point, re-sorted by path index
        idxs = np.sort(
            np.argsort(np.abs(X[:, None] - P[None, :]), axis=1)[:, :3], axis=1
        )

        drawn_path = pd.Series(index=dataset._metadata.index, dtype=float)
        drawn_path_residue = pd.Series(index=dataset._metadata.index, dtype=float)

        T1, T2, T3 = T[idxs].T
        P1, P2, P3 = P[idxs].T

        # segment P1 -> P2
        A = P2 - P1
        B = X - P1
        C = P2 - X
        d = B.real * A.real + B.imag * A.imag  # <B, A>
        e = A.real * C.real + A.imag * C.imag  # <A, C> = |A|^2 - <B, A>
        A2 = np.abs(A) ** 2
        gap = np.abs(B - d * A / A2)
        # d > 0 and e > 0  <=>  0 < d / |A|^2 < 1: foot lies inside the segment
        m = mask1 = (d > 0) * (e > 0)
        # Linear interpolation uses the segment fraction d / |A|^2 (the old
        # d / |A| scaled the offset by the segment length).
        pseudotime = T1[m] + d[m] / A2[m] * (T2[m] - T1[m])
        drawn_path.loc[m] = pseudotime
        drawn_path_residue.loc[m] = gap[m]

        # fallback: segment P1 -> P3 for the points not handled above
        if (~mask1).sum() > 0:
            A = P3[~mask1] - P1[~mask1]
            B = X[~mask1] - P1[~mask1]
            C = P3[~mask1] - X[~mask1]
            d = B.real * A.real + B.imag * A.imag
            e = A.real * C.real + A.imag * C.imag
            A2 = np.abs(A) ** 2
            gap = np.abs(B - d * A / A2)
            m = mask2 = np.zeros_like(mask1)
            mask2[~mask1] = submsk = (d > 0) * (e > 0)
            pseudotime = T1[m] + d[submsk] / A2[submsk] * (T3[m] - T1[m])
            drawn_path.loc[m] = pseudotime
            drawn_path_residue.loc[m] = gap[submsk]

        self.drawn_path = drawn_path.copy()
        self.drawn_path_residue = drawn_path_residue.copy()

        dataset.adata.uns["drawn_path"] = dict(
            t_range=[0.0, 1.0],
            periodic=path_is_closed,
            path_is_closed=path_is_closed,
        )

    def dplt_soft_path_computed_callback(
        self,
        time_points: NDArray[floating],
        data_points: NDArray[floating],
        path_points: NDArray[floating],
        path_is_closed: bool,
    ):
        """Broker callback: a hand-drawn path is ready; project the data on it
        and draw it on the figure."""
        self._compute_drawn_path(time_points, data_points, path_points, path_is_closed)
        self._assign_drawn_path_values()
        line = go.Scattergl(
            x=path_points.real, y=path_points.imag, mode="lines", line_color="black"
        )
        self.broker.publish("dplt_add_trace_request", trace=line)
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from ipywidgets import (
|
|
5
|
+
Checkbox,
|
|
6
|
+
Dropdown,
|
|
7
|
+
FloatText,
|
|
8
|
+
IntText,
|
|
9
|
+
)
|
|
10
|
+
from pandas.api.types import is_numeric_dtype
|
|
11
|
+
|
|
12
|
+
from sclab.dataset.processor import Processor
|
|
13
|
+
from sclab.dataset.processor.step import ProcessorStepBase
|
|
14
|
+
|
|
15
|
+
# One full turn in radians (not referenced by TransferMetadata below).
_2PI = np.pi * 2
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TransferMetadata(ProcessorStepBase):
    """UI wrapper around ``transfer_metadata``.

    The controls map one-to-one onto that function's keyword arguments. After
    a run, the column ``transferred_<column>`` is announced via
    ``dset_metadata_change``.
    """

    parent: Processor
    name: str = "transfer_metadata"
    description: str = "Transfer Metadata"

    run_button_description = "Transfer Metadata"

    def __init__(self, parent: Processor) -> None:
        """Build the controls, fill the dropdowns and wire the callbacks."""
        # ``IntText`` has no ``min``/``max`` traits, so ``min=3`` on it was
        # silently dropped (only a traitlets deprecation warning).
        # ``BoundedIntText`` actually enforces the lower bound.
        from ipywidgets import BoundedIntText

        variable_controls = dict(
            group_key=Dropdown(
                options=[],
                value=None,
                description="Group Key",
            ),
            source_group=Dropdown(
                options=[],
                value=None,
                description="Source Group",
            ),
            column=Dropdown(
                options=[],
                value=None,
                description="Column",
            ),
            periodic=Checkbox(
                value=False,
                description="Periodic",
            ),
            vmin=FloatText(
                value=0,
                description="Vmin",
                continuous_update=False,
            ),
            vmax=FloatText(
                value=1,
                description="Vmax",
                continuous_update=False,
            ),
            min_neighs=BoundedIntText(
                value=5,
                min=3,
                # generous cap; BoundedIntText would otherwise default to 100
                max=1000,
                description="Min Neighs",
                continuous_update=False,
            ),
            weight_by=Dropdown(
                options=["connectivity", "distance", "constant"],
                value="connectivity",
                description="Weight By",
            ),
        )

        super().__init__(
            parent=parent,
            fixed_params={},
            variable_controls=variable_controls,
        )

        self._update_groupby_options()
        self._update_column_options()
        self._update_numeric_column_controls()

        self.variable_controls["group_key"].observe(
            self._update_source_group_options, "value", "change"
        )
        self.variable_controls["column"].observe(
            self._update_numeric_column_controls, "value", "change"
        )
        self.variable_controls["periodic"].observe(
            self._update_vmin_vmax_visibility, "value", "change"
        )

    def _update_groupby_options(self, *args, **kwargs):
        """Offer every categorical metadata column as a group key."""
        metadata = self.parent.dataset._metadata.select_dtypes(include=["category"])
        options = {"": None, **{c: c for c in metadata.columns}}
        self.variable_controls["group_key"].options = options

    def _update_source_group_options(self, *args, **kwargs):
        """List the (sorted) values of the selected group key as source groups."""
        group_key = self.variable_controls["group_key"].value
        if group_key is None:
            self.variable_controls["source_group"].options = ("",)
            return

        options = self.parent.dataset._metadata[group_key].sort_values().unique()
        options = {"": None, **{c: c for c in options}}
        self.variable_controls["source_group"].options = options

    def _update_column_options(self, *args, **kwargs):
        """Offer categorical, boolean and numeric metadata columns for transfer."""
        metadata = self.parent.dataset._metadata.select_dtypes(
            include=["category", "bool", "number"]
        )
        options = {"": None, **{c: c for c in metadata.columns}}
        self.variable_controls["column"].options = options

    def _update_numeric_column_controls(self, *args, **kwargs):
        """Show the periodic/vmin/vmax controls only for numeric columns, and
        vmin/vmax only when ``periodic`` is checked."""
        column = self.variable_controls["column"].value
        if column is None:
            self._hide_control(self.variable_controls["periodic"])
            self._hide_control(self.variable_controls["vmin"])
            self._hide_control(self.variable_controls["vmax"])
            return

        series = self.parent.dataset._metadata[column]
        periodic = self.variable_controls["periodic"].value

        if is_numeric_dtype(series):
            self._show_control(self.variable_controls["periodic"])
            if periodic:
                self._show_control(self.variable_controls["vmin"])
                self._show_control(self.variable_controls["vmax"])
        else:
            self._hide_control(self.variable_controls["periodic"])
            self._hide_control(self.variable_controls["vmin"])
            self._hide_control(self.variable_controls["vmax"])

    def _update_vmin_vmax_visibility(self, *args, **kwargs):
        """Show vmin/vmax exactly when ``periodic`` is checked."""
        periodic = self.variable_controls["periodic"].value

        if periodic:
            self._show_control(self.variable_controls["vmin"])
            self._show_control(self.variable_controls["vmax"])
        else:
            self._hide_control(self.variable_controls["vmin"])
            self._hide_control(self.variable_controls["vmax"])

    def _hide_control(self, control):
        """Hide a widget and collapse its height so it takes no space."""
        control.layout.visibility = "hidden"
        control.layout.height = "0px"

    def _show_control(self, control):
        """Make a previously hidden widget visible again."""
        control.layout.visibility = "visible"
        control.layout.height = "28px"

    def function(
        self,
        group_key: str,
        source_group: str,
        column: str,
        periodic: bool = False,
        vmin: float = 0,
        vmax: float = 1,
        min_neighs: int = 5,
        weight_by: Literal["connectivity", "distance", "constant"] = "connectivity",
        **kwargs,
    ):
        """Run ``transfer_metadata`` and announce ``transferred_<column>``.

        All arguments are forwarded unchanged to ``transfer_metadata``.
        """
        from ...preprocess._transfer_metadata import transfer_metadata

        self.output.clear_output(wait=True)
        with self.output:
            transfer_metadata(
                self.parent.dataset.adata,
                group_key=group_key,
                source_group=source_group,
                column=column,
                periodic=periodic,
                vmin=vmin,
                vmax=vmax,
                min_neighs=min_neighs,
                weight_by=weight_by,
            )

        # NOTE(review): assumes transfer_metadata writes to this column name.
        new_column = f"transferred_{column}"

        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, new_column
        )

    def dset_metadata_change_callback(self, *args, **kwargs):
        """Refresh the column-based dropdowns when the metadata changes."""
        self._update_groupby_options(*args, **kwargs)
        self._update_column_options(*args, **kwargs)
|