sclab 0.1.7__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sclab/__init__.py +3 -1
- sclab/_io.py +83 -12
- sclab/_methods_registry.py +65 -0
- sclab/_sclab.py +241 -21
- sclab/dataset/_dataset.py +4 -6
- sclab/dataset/processor/_processor.py +41 -19
- sclab/dataset/processor/_results_panel.py +94 -0
- sclab/dataset/processor/step/_processor_step_base.py +12 -6
- sclab/examples/processor_steps/__init__.py +8 -0
- sclab/examples/processor_steps/_cluster.py +2 -2
- sclab/examples/processor_steps/_differential_expression.py +329 -0
- sclab/examples/processor_steps/_doublet_detection.py +68 -0
- sclab/examples/processor_steps/_gene_expression.py +125 -0
- sclab/examples/processor_steps/_integration.py +116 -0
- sclab/examples/processor_steps/_neighbors.py +26 -6
- sclab/examples/processor_steps/_pca.py +13 -8
- sclab/examples/processor_steps/_preprocess.py +52 -25
- sclab/examples/processor_steps/_qc.py +24 -8
- sclab/examples/processor_steps/_umap.py +2 -2
- sclab/gui/__init__.py +0 -0
- sclab/gui/components/__init__.py +7 -0
- sclab/gui/components/_guided_pseudotime.py +482 -0
- sclab/gui/components/_transfer_metadata.py +186 -0
- sclab/methods/__init__.py +50 -0
- sclab/preprocess/__init__.py +26 -0
- sclab/preprocess/_cca.py +176 -0
- sclab/preprocess/_cca_integrate.py +109 -0
- sclab/preprocess/_filter_obs.py +42 -0
- sclab/preprocess/_harmony.py +421 -0
- sclab/preprocess/_harmony_integrate.py +53 -0
- sclab/preprocess/_normalize_weighted.py +65 -0
- sclab/preprocess/_pca.py +51 -0
- sclab/preprocess/_preprocess.py +155 -0
- sclab/preprocess/_qc.py +38 -0
- sclab/preprocess/_rpca.py +116 -0
- sclab/preprocess/_subset.py +208 -0
- sclab/preprocess/_transfer_metadata.py +196 -0
- sclab/preprocess/_transform.py +82 -0
- sclab/preprocess/_utils.py +96 -0
- sclab/scanpy/__init__.py +0 -0
- sclab/scanpy/_compat.py +92 -0
- sclab/scanpy/_settings.py +526 -0
- sclab/scanpy/logging.py +290 -0
- sclab/scanpy/plotting/__init__.py +0 -0
- sclab/scanpy/plotting/_rcmod.py +73 -0
- sclab/scanpy/plotting/palettes.py +221 -0
- sclab/scanpy/readwrite.py +1108 -0
- sclab/tools/__init__.py +0 -0
- sclab/tools/cellflow/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
- sclab/tools/cellflow/pseudotime/__init__.py +0 -0
- sclab/tools/cellflow/pseudotime/_pseudotime.py +336 -0
- sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
- sclab/tools/cellflow/utils/__init__.py +0 -0
- sclab/tools/cellflow/utils/density_nd.py +215 -0
- sclab/tools/cellflow/utils/interpolate.py +334 -0
- sclab/tools/cellflow/utils/periodic_genes.py +106 -0
- sclab/tools/cellflow/utils/smoothen.py +124 -0
- sclab/tools/cellflow/utils/times.py +55 -0
- sclab/tools/differential_expression/__init__.py +7 -0
- sclab/tools/differential_expression/_pseudobulk_edger.py +309 -0
- sclab/tools/differential_expression/_pseudobulk_helpers.py +290 -0
- sclab/tools/differential_expression/_pseudobulk_limma.py +257 -0
- sclab/tools/doublet_detection/__init__.py +5 -0
- sclab/tools/doublet_detection/_scrublet.py +64 -0
- sclab/tools/embedding/__init__.py +0 -0
- sclab/tools/imputation/__init__.py +0 -0
- sclab/tools/imputation/_alra.py +135 -0
- sclab/tools/labeling/__init__.py +6 -0
- sclab/tools/labeling/sctype.py +233 -0
- sclab/tools/utils/__init__.py +5 -0
- sclab/tools/utils/_aggregate_and_filter.py +290 -0
- sclab/utils/__init__.py +5 -0
- sclab/utils/_write_excel.py +510 -0
- {sclab-0.1.7.dist-info → sclab-0.3.4.dist-info}/METADATA +29 -12
- sclab-0.3.4.dist-info/RECORD +93 -0
- {sclab-0.1.7.dist-info → sclab-0.3.4.dist-info}/WHEEL +1 -1
- sclab-0.3.4.dist-info/licenses/LICENSE +29 -0
- sclab-0.1.7.dist-info/RECORD +0 -30
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import plotly.express as px
|
|
4
|
+
import plotly.graph_objects as go
|
|
5
|
+
from ipywidgets import Button, Dropdown, FloatLogSlider, FloatSlider, HBox, Text
|
|
6
|
+
from numpy import floating
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
from plotly.subplots import make_subplots
|
|
9
|
+
from sclab.dataset.processor import Processor
|
|
10
|
+
from sclab.dataset.processor.step import ProcessorStepBase
|
|
11
|
+
|
|
12
|
+
# TODO: remove self.drawn_path and self.drawn_path_residue from the class
# and add them to the dataset as _drawn_path and _drawn_path_residue


# One full turn in radians. GuidedPseudotime._automatic_periodic_path_drawing
# divides the output of `periodic_parameter` by this constant to map it onto
# the unit interval used as t_range (0, 1) — presumably periodic_parameter
# returns angles in [0, 2*pi); TODO confirm.
_2PI = 2 * np.pi
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class GuidedPseudotime(ProcessorStepBase):
    """Fit a pseudotime curve guided by a path drawn on the data plot.

    The user either draws a path by hand ("Start Drawing") or requests a
    closed path fitted automatically ("Automatic Drawing"). Every cell is
    projected onto that path. The interpolated path time is written to the
    ``drawn_path`` metadata column, and the distance to the path is kept as a
    residue used to drop distant cells (``residue_threshold``).

    "Compute Pseudotime" then calls ``tools.cellflow.pseudotime.pseudotime``
    on the chosen representation (``use_rep``), with ``drawn_path`` as the
    initial time, and stores the result under ``key_added``.
    """

    parent: Processor
    name: str = "guided_pseudotime"
    description: str = "Guided Pseudotime"

    run_button_description = "Compute Pseudotime"

    def __init__(self, parent: Processor) -> None:
        variable_controls = dict(
            # FloatLogSlider: min/max are exponents (base 10 by default),
            # i.e. the threshold ranges over [1e-3, 1].
            residue_threshold=FloatLogSlider(
                value=1,
                min=-3,
                max=0,
                description="Filter by Dist.",
                continuous_update=True,
            ),
            use_rep=Dropdown(
                options=tuple(parent.dataset.adata.obsm.keys()),
                value=None,
                description="Use rep.",
            ),
            roughness=FloatSlider(
                value=0,
                min=0,
                max=3,
                step=0.05,
                description="Roughness",
                continuous_update=False,
            ),
            min_snr=FloatSlider(
                value=0.25,
                min=0,
                max=1,
                step=0.05,
                description="SNR",
                continuous_update=False,
            ),
            key_added=Text(
                value="pseudotime",
                description="Key added",
                placeholder="",
            ),
        )

        super().__init__(
            parent=parent,
            fixed_params={},
            variable_controls=variable_controls,
        )
        # Nothing can run until a representation is chosen and a path drawn.
        self.run_button.disabled = True
        self.run_button.button_style = ""
        self.estimate_start_time_button.layout.visibility = "hidden"
        self.estimate_start_time_button.layout.height = "0px"

        self.start_drawing_button.on_click(self.toggle_drawing_callback)
        self.plot_signal.on_click(self.send_signal_plot)
        self.estimate_start_time_button.on_click(
            self.estimate_periodic_pseudotime_start
        )
        self.plot_fitted_pseudotime_curve.on_click(self.send_fitted_pseudotime_plot)
        self.variable_controls["use_rep"].observe(
            self.update_buttons_state, names="value", type="change"
        )
        self.variable_controls["residue_threshold"].observe(
            self._assign_drawn_path_values, names="value", type="change"
        )

        # Any change to the fit parameters invalidates an open signal plot.
        self.variable_controls["use_rep"].observe(
            self.close_signal_plot, names="value", type="change"
        )
        self.variable_controls["roughness"].observe(
            self.close_signal_plot, names="value", type="change"
        )
        self.variable_controls["min_snr"].observe(
            self.close_signal_plot, names="value", type="change"
        )

    def make_controls(self):
        """Create the step's buttons and lay out its controls.

        Presumably invoked from ``ProcessorStepBase.__init__``: ``__init__``
        uses these buttons right after ``super().__init__`` returns.
        """
        self.start_drawing_button = Button(
            description="Start Drawing", button_style="primary"
        )

        self.auto_drawing_button = Button(
            description="Automatic Drawing", button_style="primary"
        )
        self.auto_drawing_button.on_click(self._automatic_periodic_path_drawing)

        self.plot_signal = Button(
            description="Plot Signal", button_style="info", disabled=True
        )

        self.estimate_start_time_button = Button(
            description="Estimate Pseudotime Start",
            button_style="",
            disabled=True,
        )

        self.plot_fitted_pseudotime_curve = Button(
            description="Plot Fitted Pseudotime Curve",
            button_style="info",
            disabled=True,
        )

        self.controls_list = [
            HBox([self.auto_drawing_button, self.start_drawing_button]),
            *self.variable_controls.values(),
            self.plot_signal,
            self.run_button,
            self.plot_fitted_pseudotime_curve,
            self.estimate_start_time_button,
            self.output,
        ]
        super().make_controls()

    def update_buttons_state(self, *args, **kwargs):
        """Enable the action buttons only when they can be used.

        While no representation is selected, or while a drawing is in progress
        (the drawing button no longer reads "Start Drawing"), the run/plot
        buttons are disabled and the start-time button is hidden.
        """
        drawing = self.start_drawing_button.description != "Start Drawing"
        if self.variable_controls["use_rep"].value is None or drawing:
            self.run_button.disabled = True
            self.run_button.button_style = ""

            self.plot_signal.disabled = True
            self.plot_signal.button_style = ""

            self.plot_fitted_pseudotime_curve.disabled = True
            self.plot_fitted_pseudotime_curve.button_style = ""

            self.estimate_start_time_button.disabled = True
            self.estimate_start_time_button.button_style = ""
            self.estimate_start_time_button.layout.visibility = "hidden"
            self.estimate_start_time_button.layout.height = "0px"

            return

        self.run_button.disabled = False
        self.run_button.button_style = "primary"

        self.plot_signal.disabled = False
        self.plot_signal.button_style = "info"

        self.plot_fitted_pseudotime_curve.disabled = False
        self.plot_fitted_pseudotime_curve.button_style = "info"

    def toggle_drawing_callback(self, _: Button | None = None):
        """Start or finish a free-hand drawing session on the plot.

        Publishes "dplt_start_drawing_request" / "dplt_end_drawing_request".
        The drawing button's description doubles as the state flag.
        """
        if self.start_drawing_button.description == "Start Drawing":
            self.start_drawing_button.disabled = False
            self.start_drawing_button.button_style = "warning"
            self.start_drawing_button.description = "--> click here when ready <--"

            self.update_buttons_state()

            self.broker.publish("dplt_start_drawing_request")
            self.update_output("Use your mouse pointer to draw a path on the figure")
        else:
            self.start_drawing_button.disabled = False
            self.start_drawing_button.button_style = "primary"
            self.start_drawing_button.description = "Start Drawing"

            self.update_buttons_state()

            self.broker.publish("dplt_end_drawing_request")
            self.update_output(
                "Click on the **Run** button to fit a pseudotime curve to the data"
                + " points. Make sure to select a data representation before"
                + " running the analysis."
            )

    def estimate_periodic_pseudotime_start(self, _: Button | None = None):
        """Call ``estimate_periodic_pseudotime_start`` on ``obs[key_added]``.

        Afterwards, announce the metadata change for that column.
        """
        from ...tools.cellflow.pseudotime._pseudotime import estimate_periodic_pseudotime_start

        time_key = self.variable_controls["key_added"].value
        estimate_periodic_pseudotime_start(self.parent.dataset.adata, time_key=time_key)
        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, time_key
        )
        self.estimate_start_time_button.button_style = "success"

    def send_signal_plot(self, _: Button | None = None):
        """Toggle a per-dimension "signal" plot for ``use_rep``.

        For up to 16 dimensions, each cell's value is plotted against its
        ``drawn_path`` time, together with a B-spline fit (``NDBSpline``,
        ``roughness``). Each dimension's SNR is var(fit) / var(data),
        normalized by the largest SNR. Dimensions with SNR >= ``min_snr`` are
        highlighted. A second click publishes ``figure=None`` to close the plot.
        """
        from ...tools.cellflow.utils.interpolate import NDBSpline

        if self.plot_signal.description == "Plot Signal":
            adata = self.parent.dataset.adata
            use_rep = self.variable_controls["use_rep"].value
            roughness = self.variable_controls["roughness"].value
            min_snr = self.variable_controls["min_snr"].value
            periodic = self.parent.dataset.adata.uns["drawn_path"]["path_is_closed"]

            t_range = (0.0, 1.0)
            tmin, tmax = t_range

            t = adata.obs["drawn_path"].values
            X = adata.obsm[use_rep]

            df = pd.DataFrame(
                X,
                columns=[f"Dim {i + 1}" for i in range(X.shape[1])],
                index=adata.obs_names,
            )
            df = df.join(self.parent.dataset.metadata)
            df["index"] = df.index

            # NaN times (cells filtered out of the path) compare False here
            t_mask = (tmin <= t) * (t <= tmax)
            t = t[t_mask]
            X = X[t_mask]
            df = df.loc[t_mask]

            max_dims = 16
            ndims = min(X.shape[1], max_dims)

            F = NDBSpline(t_range=t_range, periodic=periodic, roughness=roughness)
            F.fit(t, X)

            SNR: NDArray[floating] = F(t).var(axis=0) / X.var(axis=0)
            SNR = SNR / SNR.max()

            x = np.linspace(*t_range, 200)
            Y = F(x)

            rows = cols = int(np.ceil(np.sqrt(ndims)))
            titles = [f"Dim {i + 1}. SNR: {SNR[i]:.2f}" for i in range(ndims)]
            fig = make_subplots(
                rows=rows,
                cols=cols,
                shared_xaxes=True,
                shared_yaxes=False,
                x_title="Drawn path",
                y_title="Signal",
                subplot_titles=titles,
            )

            for i in range(ndims):
                row = i // cols + 1
                col = i % cols + 1
                snr = SNR[i]
                marker_color = "blue" if snr >= min_snr else "lightgray"
                line_color = "red" if snr >= min_snr else "gray"

                scatter = px.scatter(
                    df,
                    x="drawn_path",
                    y=f"Dim {i + 1}",
                    template="simple_white",
                    hover_name="index",
                )
                scatter.update_traces(marker=dict(size=5, color=marker_color))

                for trace in scatter.data:
                    fig.add_trace(trace, row=row, col=col)

                line = go.Scattergl(
                    x=x,
                    y=Y[:, i],
                    mode="lines",
                    line_color=line_color,
                )
                fig.add_trace(line, row=row, col=col)

            fig.update_layout(showlegend=False, title=f"{use_rep} Signal Plot")
            self.plot_signal.description = "Close Signal Plot"
            self.plot_signal.button_style = "warning"

        else:
            fig = None
            self.plot_signal.description = "Plot Signal"
            self.plot_signal.button_style = "info"

        self.broker.publish("dplt_plot_figure_request", figure=fig)

    def close_signal_plot(self, *args, **kwargs):
        """Reset the signal-plot button and ask the plotter to drop the figure."""
        self.plot_signal.description = "Plot Signal"
        self.plot_signal.button_style = "info"
        self.broker.publish("dplt_plot_figure_request", figure=None)

    def function(
        self,
        use_rep: str,
        roughness: float,
        min_snr: float,
        key_added: str,
        **kwargs,
    ):
        """Run ``pseudotime`` seeded by the ``drawn_path`` column.

        Uses the spline method on ``use_rep``. The path's closedness (stored
        in ``adata.uns["drawn_path"]``) decides ``periodic``. After the fit,
        the curve is drawn on the plot, and for periodic paths the
        "Estimate Pseudotime Start" button is revealed.
        """
        from ...tools.cellflow.pseudotime._pseudotime import pseudotime

        self.plot_signal.description = "Plot Signal"
        self.plot_signal.button_style = "info"

        periodic = self.parent.dataset.adata.uns["drawn_path"]["path_is_closed"]

        self.output.clear_output(wait=True)
        with self.output:
            pseudotime(
                adata=self.parent.dataset.adata,
                use_rep=use_rep,
                t_key="drawn_path",
                t_range=(0.0, 1.0),
                min_snr=min_snr,
                periodic=periodic,
                method="splines",
                roughness=roughness,
                key_added=key_added,
            )

        self.parent.dataset.clear_selected_rows()
        self.broker.publish("ctrl_data_key_value_change_request", use_rep)
        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, key_added
        )

        self.send_fitted_pseudotime_plot()
        self.update_output("")

        if periodic:
            self.estimate_start_time_button.disabled = False
            self.estimate_start_time_button.button_style = "primary"
            self.estimate_start_time_button.layout.visibility = "visible"
            self.estimate_start_time_button.layout.height = "28px"

    def send_fitted_pseudotime_plot(self, *args, **kwargs):
        """Overlay the fitted pseudotime curve on the ``use_rep`` plot.

        Reads the per-cell pseudotime ``obs[key_added]`` and the per-cell
        points on the fitted curve ``obsm[f"{key_added}_path"]``. Cells with a
        non-NaN pseudotime are ordered by pseudotime and published as a red
        line trace ("dplt_add_data_as_line_trace_request").
        """
        use_rep = self.variable_controls["use_rep"].value
        key_added = self.variable_controls["key_added"].value

        t: NDArray = self.parent.dataset.adata.obs[key_added].values
        t_mask = ~np.isnan(t)
        t = t[t_mask]

        X_path = self.parent.dataset.adata.obsm[f"{key_added}_path"]
        # Build a new frame shaped like ``dataset.data`` instead of writing
        # through ``data.values[:]``. ``.values`` is a copy for mixed dtypes
        # (the write would be lost) and a read-only view under pandas
        # Copy-on-Write (the write would raise).
        template = self.parent.dataset.data
        data = pd.DataFrame(
            np.asarray(X_path), index=template.index, columns=template.columns
        )
        data = data.loc[t_mask]
        data = data.iloc[t.argsort()]
        self.broker.publish(
            "dplt_add_data_as_line_trace_request",
            use_rep,
            data,
            name=key_added,
            line_color="red",
        )

    def _assign_drawn_path_values(self, *args, **kwargs):
        """Filter the projected path times and publish them as ``drawn_path``.

        Two filters are applied:

        1. Cells whose residue (distance to the path, divided by the largest
           residue) exceeds ``residue_threshold`` are set to NaN.
        2. Among the remaining cells, the residues are sorted in descending
           order and max-normalized. If consecutive values drop by more than
           0.25, every cell at or above the lowest such jump is also dropped
           as an outlier.

        The surviving cells become the dataset's selected rows. The times,
        clipped to [0, 1], are stored in the ``drawn_path`` metadata column.

        This method also observes the ``residue_threshold`` slider, so it is a
        no-op until ``_compute_drawn_path`` has run.
        """
        drawn_path = getattr(self, "drawn_path", None)
        drawn_path_residue = getattr(self, "drawn_path_residue", None)
        if drawn_path is None or drawn_path_residue is None:
            # slider moved before any path was drawn
            return

        dataset = self.parent.dataset

        drawn_path = drawn_path.copy()
        drawn_path_residue = drawn_path_residue.copy()

        residue_threshold = self.variable_controls["residue_threshold"].value
        x = drawn_path_residue / drawn_path_residue.max()
        drawn_path.loc[x > residue_threshold] = np.nan

        # detecting outliers: points with projection to the curve, but that are not
        # part of the cluster of points where the user drew the path
        x = drawn_path_residue.loc[drawn_path.notna()]
        x = x[~np.isnan(x)].values
        # sort in descending order
        x = np.sort(x)[::-1]
        # a jump needs at least two residues, and normalizing needs a
        # positive maximum (x[0] after the descending sort)
        if x.size > 1 and x[0] > 0:
            y = x / x[0]
            # detect jumps in the normalized values
            d = y[:-1] - y[1:]
            jumps = d > 0.25
            if jumps.any():
                # if there is a spike in the residue values, and the spike is
                # larger than 25% of the maximum residue value, then remove the
                # points with residue values at or above the lowest spike
                thr = x[:-1][jumps].min()
                drawn_path.loc[drawn_path_residue >= thr] = np.nan

        selected: NDArray = drawn_path.notna().values
        if selected.any():
            dataset.selected_rows = dataset.row_names[selected]
        else:
            dataset.selected_rows = None

        publish_change = "drawn_path" not in dataset._metadata
        dataset._metadata["drawn_path"] = drawn_path.clip(0, 1)
        if publish_change:
            self.broker.publish("dset_metadata_change", dataset.metadata)

    def _automatic_periodic_path_drawing(self, *args, **kwargs):
        """Fit a closed path through the first two plotted coordinates.

        Points are ordered with ``periodic_parameter`` (rescaled by 1/(2*pi)).
        A median-smoothed Fourier curve (``NDFourier``) is fitted and sampled
        at 1025 times on [0, 1]. The cells are then projected onto it, and the
        path is drawn as a black line.
        """
        from ...tools.cellflow.pseudotime._pseudotime import periodic_parameter
        from ...tools.cellflow.utils.interpolate import NDFourier

        data_points_array = self.parent.plotter.data_for_plot.values[:, :2]
        ordr_points_array = periodic_parameter(data_points_array) / _2PI
        F = NDFourier(t_range=(0, 1), grid_size=128, smoothing_fn=np.median)
        F.fit(ordr_points_array, data_points_array)

        time_points = np.linspace(0, 1, 1024 + 1)
        path_points_array = F(time_points)

        # complex encoding (x + iy), as expected by _compute_drawn_path
        X = data_points_array[:, 0] + 1j * data_points_array[:, 1]
        P = path_points_array[:, 0] + 1j * path_points_array[:, 1]
        self._compute_drawn_path(time_points, X, P, path_is_closed=True)
        self._assign_drawn_path_values()
        line = go.Scattergl(x=P.real, y=P.imag, mode="lines", line_color="black")
        self.broker.publish("dplt_add_trace_request", trace=line)

    @staticmethod
    def _project_onto_segments(start, end, points):
        """Project complex ``points`` onto the segments ``start -> end``.

        Returns ``(fraction, gap, inside)``:

        - ``fraction`` is the position of the projection along the segment
          (0 at ``start``, 1 at ``end``).
        - ``gap`` is the distance from the point to its projection.
        - ``inside`` is True where the projection lies strictly between the
          endpoints.

        Zero-length segments give NaN fraction/gap and ``inside`` False.
        """
        A = end - start
        B = points - start
        C = end - points
        d = B.real * A.real + B.imag * A.imag  # B . A
        e = A.real * C.real + A.imag * C.imag  # A . C  (= |A|^2 - B . A)
        with np.errstate(divide="ignore", invalid="ignore"):
            fraction = d / np.abs(A) ** 2
            gap = np.abs(B - fraction * A)
        inside = (d > 0) & (e > 0)
        return fraction, gap, inside

    def _compute_drawn_path(
        self,
        time_points: NDArray[floating],
        data_points: NDArray[floating],
        path_points: NDArray[floating],
        path_is_closed: bool,
    ):
        """Project every data point onto a polyline path.

        ``data_points`` and ``path_points`` are complex arrays (x + iy), and
        ``time_points`` holds the path time of each path vertex.

        For each data point, the three nearest path vertices (in path order)
        are found. The point is projected onto the segment between the first
        two, or, if it does not fall inside that segment, onto the segment
        between the first and the third. Its time is linearly interpolated
        along that segment. Points falling inside neither segment stay NaN.

        Results are stored as ``self.drawn_path`` (time) and
        ``self.drawn_path_residue`` (distance to the path). The path settings
        are written to ``adata.uns["drawn_path"]``.
        """
        dataset = self.parent.dataset

        T = time_points
        X = data_points
        P = path_points

        # indices of the three path vertices closest to each data point,
        # re-sorted so that they follow the path order
        idxs = np.sort(
            np.argsort(np.abs(X[:, None] - P[None, :]), axis=1)[:, :3], axis=1
        )

        drawn_path = pd.Series(index=dataset._metadata.index, dtype=float)
        drawn_path_residue = pd.Series(index=dataset._metadata.index, dtype=float)

        T1, T2, T3 = T[idxs].T
        P1, P2, P3 = P[idxs].T

        # first try the segment P1 -> P2
        fraction, gap, mask1 = self._project_onto_segments(P1, P2, X)
        m = mask1
        drawn_path.loc[m] = T1[m] + fraction[m] * (T2[m] - T1[m])
        drawn_path_residue.loc[m] = gap[m]

        # fall back to the segment P1 -> P3 for the remaining points
        rest = ~mask1
        if rest.any():
            fraction, gap, submsk = self._project_onto_segments(
                P1[rest], P3[rest], X[rest]
            )
            m = mask2 = np.zeros_like(mask1)
            mask2[rest] = submsk
            drawn_path.loc[m] = T1[m] + fraction[submsk] * (T3[m] - T1[m])
            drawn_path_residue.loc[m] = gap[submsk]

        self.drawn_path = drawn_path.copy()
        self.drawn_path_residue = drawn_path_residue.copy()

        dataset.adata.uns["drawn_path"] = dict(
            t_range=[0.0, 1.0],
            periodic=path_is_closed,
            path_is_closed=path_is_closed,
        )

    def dplt_soft_path_computed_callback(
        self,
        time_points: NDArray[floating],
        data_points: NDArray[floating],
        path_points: NDArray[floating],
        path_is_closed: bool,
    ):
        """Handle a path drawn on the plot.

        Projects the cells onto the path, applies the residue filters and
        draws the path as a black line. ``path_points`` is used as a complex
        array (``.real`` / ``.imag``).
        """
        self._compute_drawn_path(time_points, data_points, path_points, path_is_closed)
        self._assign_drawn_path_values()
        line = go.Scattergl(
            x=path_points.real, y=path_points.imag, mode="lines", line_color="black"
        )
        self.broker.publish("dplt_add_trace_request", trace=line)
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from ipywidgets import (
|
|
5
|
+
Checkbox,
|
|
6
|
+
Dropdown,
|
|
7
|
+
FloatText,
|
|
8
|
+
IntText,
|
|
9
|
+
)
|
|
10
|
+
from pandas.api.types import is_numeric_dtype
|
|
11
|
+
|
|
12
|
+
from sclab.dataset.processor import Processor
|
|
13
|
+
from sclab.dataset.processor.step import ProcessorStepBase
|
|
14
|
+
|
|
15
|
+
# One full turn in radians.
# NOTE(review): not referenced by TransferMetadata in this module — possibly a
# leftover; confirm no other importer relies on it before removing.
_2PI = 2 * np.pi
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TransferMetadata(ProcessorStepBase):
    """Processor step wrapping ``preprocess._transfer_metadata.transfer_metadata``.

    The user picks:

    - a categorical grouping column (``group_key``);
    - one of its groups as the source (``source_group``);
    - a metadata ``column`` to transfer.

    For numeric (non-boolean) columns, a ``periodic`` flag and its
    ``vmin``/``vmax`` range are exposed. ``min_neighs`` and ``weight_by``
    suggest the transfer goes through a neighbor graph — confirm in
    ``transfer_metadata``.
    """

    parent: Processor
    name: str = "transfer_metadata"
    description: str = "Transfer Metadata"

    run_button_description = "Transfer Metadata"

    def __init__(self, parent: Processor) -> None:
        variable_controls = dict(
            group_key=Dropdown(
                options=[],
                value=None,
                description="Group Key",
            ),
            source_group=Dropdown(
                options=[],
                value=None,
                description="Source Group",
            ),
            column=Dropdown(
                options=[],
                value=None,
                description="Column",
            ),
            periodic=Checkbox(
                value=False,
                description="Periodic",
            ),
            vmin=FloatText(
                value=0,
                description="Vmin",
                continuous_update=False,
            ),
            vmax=FloatText(
                value=1,
                description="Vmax",
                continuous_update=False,
            ),
            min_neighs=IntText(
                value=5,
                min=3,
                description="Min Neighs",
                continuous_update=False,
            ),
            weight_by=Dropdown(
                options=["connectivity", "distance", "constant"],
                value="connectivity",
                description="Weight By",
            ),
        )

        super().__init__(
            parent=parent,
            fixed_params={},
            variable_controls=variable_controls,
        )

        self._update_groupby_options()
        self._update_column_options()
        self._update_numeric_column_controls()

        self.variable_controls["group_key"].observe(
            self._update_source_group_options, "value", "change"
        )
        self.variable_controls["column"].observe(
            self._update_numeric_column_controls, "value", "change"
        )
        self.variable_controls["periodic"].observe(
            self._update_vmin_vmax_visibility, "value", "change"
        )

    def _update_groupby_options(self, *args, **kwargs):
        """Offer every categorical metadata column (plus a blank entry) as group key."""
        metadata = self.parent.dataset._metadata.select_dtypes(include=["category"])
        options = {"": None, **{c: c for c in metadata.columns}}
        self.variable_controls["group_key"].options = options

    def _update_source_group_options(self, *args, **kwargs):
        """List the sorted unique values of the selected group key as source groups.

        The blank entry maps to ``None``, consistent with the other dropdowns.
        """
        group_key = self.variable_controls["group_key"].value
        if group_key is None:
            self.variable_controls["source_group"].options = {"": None}
            return

        options = self.parent.dataset._metadata[group_key].sort_values().unique()
        options = {"": None, **{c: c for c in options}}
        self.variable_controls["source_group"].options = options

    def _update_column_options(self, *args, **kwargs):
        """Offer categorical, boolean and numeric metadata columns for transfer."""
        metadata = self.parent.dataset._metadata.select_dtypes(
            include=["category", "bool", "number"]
        )
        options = {"": None, **{c: c for c in metadata.columns}}
        self.variable_controls["column"].options = options

    def _update_numeric_column_controls(self, *args, **kwargs):
        """Show the periodic/vmin/vmax controls only for numeric columns.

        ``is_numeric_dtype`` is True for booleans, so bool columns are
        excluded explicitly; they are offered separately from "number" in
        the column options.
        """
        from pandas.api.types import is_bool_dtype

        column = self.variable_controls["column"].value
        periodic_control = self.variable_controls["periodic"]

        if column is None:
            is_numeric = False
        else:
            series = self.parent.dataset._metadata[column]
            is_numeric = is_numeric_dtype(series) and not is_bool_dtype(series)

        if not is_numeric:
            # Hidden controls are still passed to ``function``. Reset the flag
            # so a stale True from a previous numeric column does not leak
            # into a categorical/boolean transfer.
            periodic_control.value = False
            self._hide_control(periodic_control)
            self._hide_control(self.variable_controls["vmin"])
            self._hide_control(self.variable_controls["vmax"])
            return

        self._show_control(periodic_control)
        self._update_vmin_vmax_visibility()

    def _update_vmin_vmax_visibility(self, *args, **kwargs):
        """Show vmin/vmax only while the periodic flag is set."""
        periodic = self.variable_controls["periodic"].value

        if periodic:
            self._show_control(self.variable_controls["vmin"])
            self._show_control(self.variable_controls["vmax"])
        else:
            self._hide_control(self.variable_controls["vmin"])
            self._hide_control(self.variable_controls["vmax"])

    def _hide_control(self, control):
        """Hide a widget and collapse the space it occupies."""
        control.layout.visibility = "hidden"
        control.layout.height = "0px"

    def _show_control(self, control):
        """Make a widget visible again at the standard row height."""
        control.layout.visibility = "visible"
        control.layout.height = "28px"

    def function(
        self,
        group_key: str,
        source_group: str,
        column: str,
        periodic: bool = False,
        vmin: float = 0,
        vmax: float = 1,
        min_neighs: int = 5,
        weight_by: Literal["connectivity", "distance", "constant"] = "connectivity",
        **kwargs,
    ):
        """Run ``transfer_metadata`` with the current control values.

        Afterwards, announces a metadata change for ``transferred_{column}``.
        This assumes that is the column ``transfer_metadata`` writes — TODO
        confirm against ``preprocess._transfer_metadata``.
        """
        from ...preprocess._transfer_metadata import transfer_metadata

        self.output.clear_output(wait=True)
        with self.output:
            transfer_metadata(
                self.parent.dataset.adata,
                group_key=group_key,
                source_group=source_group,
                column=column,
                periodic=periodic,
                vmin=vmin,
                vmax=vmax,
                min_neighs=min_neighs,
                weight_by=weight_by,
            )

        new_column = f"transferred_{column}"

        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, new_column
        )

    def dset_metadata_change_callback(self, *args, **kwargs):
        """Refresh the group-key and column choices after any metadata change."""
        self._update_groupby_options(*args, **kwargs)
        self._update_column_options(*args, **kwargs)
|