sclab 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sclab might be problematic. Click here for more details.

Files changed (53)
  1. sclab/__init__.py +1 -1
  2. sclab/_sclab.py +7 -3
  3. sclab/dataset/_dataset.py +1 -1
  4. sclab/dataset/processor/_processor.py +19 -4
  5. sclab/examples/processor_steps/__init__.py +2 -0
  6. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  7. sclab/examples/processor_steps/_integration.py +47 -20
  8. sclab/examples/processor_steps/_neighbors.py +24 -4
  9. sclab/examples/processor_steps/_pca.py +11 -6
  10. sclab/examples/processor_steps/_preprocess.py +14 -1
  11. sclab/examples/processor_steps/_qc.py +22 -6
  12. sclab/gui/__init__.py +0 -0
  13. sclab/gui/components/__init__.py +7 -0
  14. sclab/gui/components/_guided_pseudotime.py +482 -0
  15. sclab/gui/components/_transfer_metadata.py +186 -0
  16. sclab/methods/__init__.py +16 -0
  17. sclab/preprocess/__init__.py +19 -0
  18. sclab/preprocess/_cca.py +154 -0
  19. sclab/preprocess/_cca_integrate.py +109 -0
  20. sclab/preprocess/_filter_obs.py +42 -0
  21. sclab/preprocess/_harmony.py +421 -0
  22. sclab/preprocess/_harmony_integrate.py +53 -0
  23. sclab/preprocess/_normalize_weighted.py +61 -0
  24. sclab/preprocess/_subset.py +208 -0
  25. sclab/preprocess/_transfer_metadata.py +137 -0
  26. sclab/preprocess/_transform.py +82 -0
  27. sclab/preprocess/_utils.py +96 -0
  28. sclab/tools/__init__.py +0 -0
  29. sclab/tools/cellflow/__init__.py +0 -0
  30. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  31. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  32. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  33. sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
  34. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  35. sclab/tools/cellflow/utils/__init__.py +0 -0
  36. sclab/tools/cellflow/utils/density_nd.py +215 -0
  37. sclab/tools/cellflow/utils/interpolate.py +334 -0
  38. sclab/tools/cellflow/utils/smoothen.py +124 -0
  39. sclab/tools/cellflow/utils/times.py +55 -0
  40. sclab/tools/differential_expression/__init__.py +5 -0
  41. sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
  42. sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
  43. sclab/tools/doublet_detection/__init__.py +5 -0
  44. sclab/tools/doublet_detection/_scrublet.py +64 -0
  45. sclab/tools/labeling/__init__.py +6 -0
  46. sclab/tools/labeling/sctype.py +233 -0
  47. sclab/utils/__init__.py +5 -0
  48. sclab/utils/_write_excel.py +510 -0
  49. {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/METADATA +6 -2
  50. sclab-0.3.1.dist-info/RECORD +82 -0
  51. sclab-0.2.5.dist-info/RECORD +0 -45
  52. {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/WHEEL +0 -0
  53. {sclab-0.2.5.dist-info → sclab-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,482 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ from ipywidgets import Button, Dropdown, FloatLogSlider, FloatSlider, HBox, Text
6
+ from numpy import floating
7
+ from numpy.typing import NDArray
8
+ from plotly.subplots import make_subplots
9
+ from sclab.dataset.processor import Processor
10
+ from sclab.dataset.processor.step import ProcessorStepBase
11
+
12
+ # TODO: remove self.drawn_path and self.drawn_path_residue from the class
13
+ # and add them to the dataset as _drawn_path and _drawn_path_residue
14
+
15
+
16
+ _2PI = 2 * np.pi
17
+
18
+
19
class GuidedPseudotime(ProcessorStepBase):
    """Processor step that fits a pseudotime curve guided by a user-drawn path.

    A path is drawn over the current 2-D plot, either by hand
    (``toggle_drawing_callback`` + ``dplt_soft_path_computed_callback``) or
    automatically for closed paths (``_automatic_periodic_path_drawing``).
    Every cell is projected onto the path, which gives a ``drawn_path``
    coordinate (clipped to ``[0, 1]``) and a residue (its distance to the path).
    ``function`` then calls ``tools.cellflow.pseudotime.pseudotime`` with those
    coordinates on the selected representation. The fitted result is read back
    from ``adata.obs[key_added]`` and ``adata.obsm[f"{key_added}_path"]``.
    """

    parent: Processor
    name: str = "guided_pseudotime"
    description: str = "Guided Pseudotime"

    run_button_description = "Compute Pseudotime"

    # Outlier rule used by `_assign_drawn_path_values`. Sort the residues in
    # descending order and normalize them by the maximum. A drop larger than
    # this fraction between two consecutive values marks the points above the
    # drop as outliers.
    _RESIDUE_JUMP_FRACTION: float = 0.25

    def __init__(self, parent: Processor) -> None:
        variable_controls = dict(
            residue_threshold=FloatLogSlider(
                value=1,
                min=-3,
                max=0,
                description="Filter by Dist.",
                continuous_update=True,
            ),
            use_rep=Dropdown(
                options=tuple(parent.dataset.adata.obsm.keys()),
                value=None,
                description="Use rep.",
            ),
            roughness=FloatSlider(
                value=0,
                min=0,
                max=3,
                step=0.05,
                description="Roughness",
                continuous_update=False,
            ),
            min_snr=FloatSlider(
                value=0.25,
                min=0,
                max=1,
                step=0.05,
                description="SNR",
                continuous_update=False,
            ),
            key_added=Text(
                value="pseudotime",
                description="Key added",
                placeholder="",
            ),
        )

        super().__init__(
            parent=parent,
            fixed_params={},
            variable_controls=variable_controls,
        )
        # The run button stays disabled until `update_buttons_state` enables it.
        self.run_button.disabled = True
        self.run_button.button_style = ""
        self.estimate_start_time_button.layout.visibility = "hidden"
        self.estimate_start_time_button.layout.height = "0px"

        self.start_drawing_button.on_click(self.toggle_drawing_callback)
        self.plot_signal.on_click(self.send_signal_plot)
        self.estimate_start_time_button.on_click(
            self.estimate_periodic_pseudotime_start
        )
        self.plot_fitted_pseudotime_curve.on_click(self.send_fitted_pseudotime_plot)
        self.variable_controls["use_rep"].observe(
            self.update_buttons_state, names="value", type="change"
        )
        self.variable_controls["residue_threshold"].observe(
            self._assign_drawn_path_values, names="value", type="change"
        )

        # Changing any fit parameter invalidates an open signal plot.
        for control_name in ("use_rep", "roughness", "min_snr"):
            self.variable_controls[control_name].observe(
                self.close_signal_plot, names="value", type="change"
            )

    def make_controls(self):
        """Create the step's buttons and lay out all controls."""
        self.start_drawing_button = Button(
            description="Start Drawing", button_style="primary"
        )

        self.auto_drawing_button = Button(
            description="Automatic Drawing", button_style="primary"
        )
        self.auto_drawing_button.on_click(self._automatic_periodic_path_drawing)

        self.plot_signal = Button(
            description="Plot Signal", button_style="info", disabled=True
        )

        self.estimate_start_time_button = Button(
            description="Estimate Pseudotime Start",
            button_style="",
            disabled=True,
        )

        self.plot_fitted_pseudotime_curve = Button(
            description="Plot Fitted Pseudotime Curve",
            button_style="info",
            disabled=True,
        )

        self.controls_list = [
            HBox([self.auto_drawing_button, self.start_drawing_button]),
            *self.variable_controls.values(),
            self.plot_signal,
            self.run_button,
            self.plot_fitted_pseudotime_curve,
            self.estimate_start_time_button,
            self.output,
        ]
        super().make_controls()

    def update_buttons_state(self, *args, **kwargs):
        """Enable or disable the action buttons.

        The buttons are disabled (and the start-estimation button is hidden)
        while no representation is selected or while drawing is in progress.
        Otherwise the run, signal-plot and fitted-curve buttons are enabled.
        """
        drawing = self.start_drawing_button.description != "Start Drawing"
        if self.variable_controls["use_rep"].value is None or drawing:
            self.run_button.disabled = True
            self.run_button.button_style = ""

            self.plot_signal.disabled = True
            self.plot_signal.button_style = ""

            self.plot_fitted_pseudotime_curve.disabled = True
            self.plot_fitted_pseudotime_curve.button_style = ""

            self.estimate_start_time_button.disabled = True
            self.estimate_start_time_button.button_style = ""
            self.estimate_start_time_button.layout.visibility = "hidden"
            self.estimate_start_time_button.layout.height = "0px"

            return

        self.run_button.disabled = False
        self.run_button.button_style = "primary"

        self.plot_signal.disabled = False
        self.plot_signal.button_style = "info"

        self.plot_fitted_pseudotime_curve.disabled = False
        self.plot_fitted_pseudotime_curve.button_style = "info"

    def toggle_drawing_callback(self, _: Button | None = None):
        """Start or finish manual path drawing on the plot.

        The same button is used for both actions; its description tells which
        state the step is in. The start/end request is published to the plotter
        via the broker.
        """
        if self.start_drawing_button.description == "Start Drawing":
            self.start_drawing_button.disabled = False
            self.start_drawing_button.button_style = "warning"
            self.start_drawing_button.description = "--> click here when ready <--"

            self.update_buttons_state()

            self.broker.publish("dplt_start_drawing_request")
            self.update_output("Use your mouse pointer to draw a path on the figure")
        else:
            self.start_drawing_button.disabled = False
            self.start_drawing_button.button_style = "primary"
            self.start_drawing_button.description = "Start Drawing"

            self.update_buttons_state()

            self.broker.publish("dplt_end_drawing_request")
            self.update_output(
                "Click on the **Run** button to fit a pseudotime curve to the data"
                + " points. Make sure to select a data representation before"
                + " running the analysis."
            )

    def estimate_periodic_pseudotime_start(self, _: Button | None = None):
        """Run ``estimate_periodic_pseudotime_start`` on ``key_added``.

        Publishes a metadata change for that column when done.
        """
        from ...tools.cellflow.pseudotime._pseudotime import (
            estimate_periodic_pseudotime_start,
        )

        time_key = self.variable_controls["key_added"].value
        estimate_periodic_pseudotime_start(self.parent.dataset.adata, time_key=time_key)
        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, time_key
        )
        self.estimate_start_time_button.button_style = "success"

    def send_signal_plot(self, _: Button | None = None):
        """Toggle a per-dimension plot of the representation against the drawn path.

        When opening the plot:
        - An ``NDBSpline`` is fitted to ``adata.obsm[use_rep]`` as a function of
          the ``drawn_path`` coordinate.
        - For at most the first 16 dimensions, the data points and the fitted
          curve are plotted.
        - Each subplot title shows its SNR: the variance of the fit divided by
          the variance of the data, normalized so that the largest SNR is 1.
        - Dimensions whose SNR is below ``min_snr`` are drawn in grey.

        When closing, an empty figure is published instead.
        """
        if self.plot_signal.description != "Plot Signal":
            self.close_signal_plot()
            return

        from ...tools.cellflow.utils.interpolate import NDBSpline

        adata = self.parent.dataset.adata
        use_rep = self.variable_controls["use_rep"].value
        roughness = self.variable_controls["roughness"].value
        min_snr = self.variable_controls["min_snr"].value
        periodic = adata.uns["drawn_path"]["path_is_closed"]

        t_range = (0.0, 1.0)
        tmin, tmax = t_range

        t = adata.obs["drawn_path"].values
        X = adata.obsm[use_rep]

        df = pd.DataFrame(
            X,
            columns=[f"Dim {i + 1}" for i in range(X.shape[1])],
            index=adata.obs_names,
        )
        df = df.join(self.parent.dataset.metadata)
        df["index"] = df.index

        # Cells that were not projected onto the path (NaN) fail both
        # comparisons and are excluded.
        t_mask = (tmin <= t) & (t <= tmax)
        t = t[t_mask]
        X = X[t_mask]
        df = df.loc[t_mask]

        max_dims = 16
        ndims = min(X.shape[1], max_dims)

        F = NDBSpline(t_range=t_range, periodic=periodic, roughness=roughness)
        F.fit(t, X)

        SNR: NDArray[floating] = F(t).var(axis=0) / X.var(axis=0)
        SNR = SNR / SNR.max()

        x = np.linspace(*t_range, 200)
        Y = F(x)

        rows = cols = int(np.ceil(np.sqrt(ndims)))
        titles = [f"Dim {i + 1}. SNR: {SNR[i]:.2f}" for i in range(ndims)]
        fig = make_subplots(
            rows=rows,
            cols=cols,
            shared_xaxes=True,
            shared_yaxes=False,
            x_title="Drawn path",
            y_title="Signal",
            subplot_titles=titles,
        )

        for i in range(ndims):
            row = i // cols + 1
            col = i % cols + 1
            snr = SNR[i]
            marker_color = "blue" if snr >= min_snr else "lightgray"
            line_color = "red" if snr >= min_snr else "gray"

            scatter = px.scatter(
                df,
                x="drawn_path",
                y=f"Dim {i + 1}",
                template="simple_white",
                hover_name="index",
            )
            scatter.update_traces(marker=dict(size=5, color=marker_color))

            for trace in scatter.data:
                fig.add_trace(trace, row=row, col=col)

            line = go.Scattergl(
                x=x,
                y=Y[:, i],
                mode="lines",
                line_color=line_color,
            )
            fig.add_trace(line, row=row, col=col)

        fig.update_layout(showlegend=False, title=f"{use_rep} Signal Plot")
        self.plot_signal.description = "Close Signal Plot"
        self.plot_signal.button_style = "warning"

        self.broker.publish("dplt_plot_figure_request", figure=fig)

    def close_signal_plot(self, *args, **kwargs):
        """Reset the signal-plot button and publish an empty figure."""
        self.plot_signal.description = "Plot Signal"
        self.plot_signal.button_style = "info"
        self.broker.publish("dplt_plot_figure_request", figure=None)

    def function(
        self,
        use_rep: str,
        roughness: float,
        min_snr: float,
        key_added: str,
        **kwargs,
    ):
        """Fit the pseudotime curve, using the drawn path as the initial ordering.

        Parameters
        ----------
        use_rep : str
            Key of ``adata.obsm`` to fit in.
        roughness : float
            Roughness passed to the spline fit.
        min_snr : float
            Minimum normalized SNR passed to ``pseudotime``.
        key_added : str
            Key under which ``pseudotime`` stores its result.

        The path is treated as periodic when it was drawn closed. For periodic
        paths the "Estimate Pseudotime Start" button is revealed afterwards.
        """
        from ...tools.cellflow.pseudotime._pseudotime import pseudotime

        self.plot_signal.description = "Plot Signal"
        self.plot_signal.button_style = "info"

        periodic = self.parent.dataset.adata.uns["drawn_path"]["path_is_closed"]

        self.output.clear_output(wait=True)
        with self.output:
            pseudotime(
                adata=self.parent.dataset.adata,
                use_rep=use_rep,
                t_key="drawn_path",
                t_range=(0.0, 1.0),
                min_snr=min_snr,
                periodic=periodic,
                method="splines",
                roughness=roughness,
                key_added=key_added,
            )

        self.parent.dataset.clear_selected_rows()
        self.broker.publish("ctrl_data_key_value_change_request", use_rep)
        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, key_added
        )

        self.send_fitted_pseudotime_plot()
        self.update_output("")

        if periodic:
            self.estimate_start_time_button.disabled = False
            self.estimate_start_time_button.button_style = "primary"
            self.estimate_start_time_button.layout.visibility = "visible"
            self.estimate_start_time_button.layout.height = "28px"

    def send_fitted_pseudotime_plot(self, *args, **kwargs):
        """Overlay the fitted pseudotime curve on the plot of ``use_rep``.

        The steps are:
        1. Read the pseudotime from ``adata.obs[key_added]`` and the curve
           coordinates from ``adata.obsm[f"{key_added}_path"]``.
        2. Keep only the cells that have a pseudotime.
        3. Order them by pseudotime.
        4. Publish them as a red line trace named ``key_added``.
        """
        use_rep = self.variable_controls["use_rep"].value
        key_added = self.variable_controls["key_added"].value
        adata = self.parent.dataset.adata

        t_series = adata.obs[key_added]
        t_mask: NDArray = t_series.notna().to_numpy()
        t: NDArray = t_series.to_numpy(dtype=float, na_value=np.nan)[t_mask]

        # Build a new frame shaped like the plotted data instead of writing into
        # `DataFrame.values`. That array can be a detached copy (mixed dtypes)
        # or a read-only view (pandas Copy-on-Write), and in both cases the
        # assignment would not take effect as intended.
        template: pd.DataFrame = self.parent.dataset.data
        X_path = np.asarray(adata.obsm[f"{key_added}_path"])
        data = pd.DataFrame(X_path, index=template.index, columns=template.columns)
        data = data.loc[t_mask]
        data = data.iloc[t.argsort()]
        self.broker.publish(
            "dplt_add_data_as_line_trace_request",
            use_rep,
            data,
            name=key_added,
            line_color="red",
        )

    def _assign_drawn_path_values(self, *args, **kwargs):
        """Filter the projected path coordinates and store them as ``drawn_path``.

        This starts from the coordinates and residues computed by
        ``_compute_drawn_path`` and applies two filters:

        1. Points are discarded when their residue, divided by the maximum
           residue, exceeds the ``residue_threshold`` slider value.
        2. Outliers are discarded. Sort the remaining residues in descending
           order and normalize them by their maximum. Wherever two consecutive
           values differ by more than ``_RESIDUE_JUMP_FRACTION``, take the
           residue just above that jump. Every point whose residue is at least
           the smallest such value is dropped.

        The surviving points become the dataset's selected rows. Their
        coordinates, clipped to ``[0, 1]``, are written to
        ``dataset._metadata["drawn_path"]``. If no path has been computed yet
        (e.g. the slider moved before any drawing), the method does nothing.
        """
        base_path = getattr(self, "drawn_path", None)
        base_residue = getattr(self, "drawn_path_residue", None)
        if base_path is None or base_residue is None:
            return

        dataset = self.parent.dataset

        drawn_path = base_path.copy()
        drawn_path_residue = base_residue.copy()

        residue_threshold = self.variable_controls["residue_threshold"].value
        x = drawn_path_residue / drawn_path_residue.max()
        drawn_path.loc[x > residue_threshold] = np.nan

        # Detect outliers: points that have a projection onto the curve but do
        # not belong to the cloud of points around which the path was drawn.
        residues = drawn_path_residue.loc[drawn_path.notna()].to_numpy()
        residues = residues[~np.isnan(residues)]
        residues = np.sort(residues)[::-1]  # descending
        # At least two residues are needed to see a jump. An all-zero set
        # cannot contain one (and would otherwise divide by zero).
        if residues.size > 1 and residues[0] > 0:
            y = residues / residues[0]
            drops = y[:-1] - y[1:]
            jumps = drops > self._RESIDUE_JUMP_FRACTION
            if jumps.any():
                thr = residues[:-1][jumps].min()
                drawn_path.loc[drawn_path_residue >= thr] = np.nan

        selected: NDArray = drawn_path.notna().values
        if selected.any():
            dataset.selected_rows = dataset.row_names[selected]
        else:
            dataset.selected_rows = None

        publish_change = "drawn_path" not in dataset._metadata
        dataset._metadata["drawn_path"] = drawn_path.clip(0, 1)
        if publish_change:
            self.broker.publish("dset_metadata_change", dataset.metadata)

    def _automatic_periodic_path_drawing(self, *args, **kwargs):
        """Draw a closed path automatically through the plotted 2-D points.

        The steps are:
        1. Order the first two plotted columns with ``periodic_parameter``,
           rescaled by 2π.
        2. Fit an ``NDFourier`` curve over ``t`` in ``[0, 1]``.
        3. Sample the curve at 1025 points.
        4. Project the data onto it and store the path values.
        5. Publish the path as a black line trace.
        """
        from ...tools.cellflow.pseudotime._pseudotime import periodic_parameter
        from ...tools.cellflow.utils.interpolate import NDFourier

        data_points_array = self.parent.plotter.data_for_plot.values[:, :2]
        ordr_points_array = periodic_parameter(data_points_array) / _2PI
        F = NDFourier(t_range=(0, 1), grid_size=128, smoothing_fn=np.median)
        F.fit(ordr_points_array, data_points_array)

        T = np.linspace(0, 1, 1024 + 1)
        path_points_array = F(T)

        # Encode 2-D points as complex numbers (x + iy) for `_compute_drawn_path`.
        X = data_points_array[:, 0] + 1j * data_points_array[:, 1]
        P = path_points_array[:, 0] + 1j * path_points_array[:, 1]
        self._compute_drawn_path(T, X, P, path_is_closed=True)
        self._assign_drawn_path_values()
        line = go.Scattergl(x=P.real, y=P.imag, mode="lines", line_color="black")
        self.broker.publish("dplt_add_trace_request", trace=line)

    @staticmethod
    def _project_onto_segments(
        X: NDArray, P_start: NDArray, P_end: NDArray
    ) -> tuple[NDArray, NDArray, NDArray]:
        """Project points orthogonally onto segments, element by element.

        All arguments are 1-D complex arrays of equal length, with 2-D points
        encoded as ``x + 1j * y``.

        Returns
        -------
        frac : NDArray
            Position of each projection along its segment, as a fraction:
            0 at ``P_start`` and 1 at ``P_end``.
        gap : NDArray
            Distance from each point to its projection.
        inside : NDArray of bool
            True where the projection falls strictly inside the segment.
        """
        A = P_end - P_start
        B = X - P_start
        C = P_end - X
        # 2-D dot products: d = B·A, and e = A·C = |A|^2 - B·A.
        d = B.real * A.real + B.imag * A.imag
        e = A.real * C.real + A.imag * C.imag
        with np.errstate(divide="ignore", invalid="ignore"):
            # A zero-length segment gives NaN here. It is never "inside",
            # because d > 0 is impossible when A == 0.
            frac = d / np.abs(A) ** 2
            gap = np.abs(B - frac * A)
        inside = (d > 0) & (e > 0)
        return frac, gap, inside

    def _compute_drawn_path(
        self,
        time_points: NDArray[floating],
        data_points: NDArray[floating],
        path_points: NDArray[floating],
        path_is_closed: bool,
    ):
        """Project every data point onto the drawn path.

        ``data_points`` and ``path_points`` are complex-encoded 2-D points, as
        built by the callers via ``x + 1j * y``. ``time_points[k]`` is the path
        parameter of ``path_points[k]``.

        For each data point, the three nearest path vertices are found and
        sorted in path order, giving P1, P2 and P3. The point is then projected
        onto segment P1-P2, or onto P1-P3 if that failed. Its path coordinate is
        the linear interpolation of the vertices' time points at the projection.
        Its residue is the distance to the projection. Points that fall inside
        neither segment stay NaN.

        Results are stored in ``self.drawn_path`` and
        ``self.drawn_path_residue``. The path metadata is written to
        ``adata.uns["drawn_path"]``.
        """
        dataset = self.parent.dataset

        T = time_points
        X = data_points
        P = path_points

        # Indices of the 3 nearest path vertices for each point, in path order.
        distances = np.abs(X[:, None] - P[None, :])
        idxs = np.sort(np.argpartition(distances, 2, axis=1)[:, :3], axis=1)

        drawn_path = pd.Series(index=dataset._metadata.index, dtype=float)
        drawn_path_residue = pd.Series(index=dataset._metadata.index, dtype=float)

        T1, T2, T3 = T[idxs].T
        P1, P2, P3 = P[idxs].T

        # Try segment P1 -> P2 first.
        frac, gap, mask1 = self._project_onto_segments(X, P1, P2)
        drawn_path.loc[mask1] = T1[mask1] + frac[mask1] * (T2[mask1] - T1[mask1])
        drawn_path_residue.loc[mask1] = gap[mask1]

        # Retry the remaining points on segment P1 -> P3.
        rest = ~mask1
        if rest.any():
            frac, gap, submsk = self._project_onto_segments(
                X[rest], P1[rest], P3[rest]
            )
            mask2 = np.zeros_like(mask1)
            mask2[rest] = submsk
            drawn_path.loc[mask2] = T1[mask2] + frac[submsk] * (
                T3[mask2] - T1[mask2]
            )
            drawn_path_residue.loc[mask2] = gap[submsk]

        self.drawn_path = drawn_path.copy()
        self.drawn_path_residue = drawn_path_residue.copy()

        dataset.adata.uns["drawn_path"] = dict(
            t_range=[0.0, 1.0],
            periodic=path_is_closed,
            path_is_closed=path_is_closed,
        )

    def dplt_soft_path_computed_callback(
        self,
        time_points: NDArray[floating],
        data_points: NDArray[floating],
        path_points: NDArray[floating],
        path_is_closed: bool,
    ):
        """Handle a manually drawn path.

        Projects the data onto the path, applies the filters, and publishes
        the path as a black line trace.
        """
        self._compute_drawn_path(time_points, data_points, path_points, path_is_closed)
        self._assign_drawn_path_values()
        line = go.Scattergl(
            x=path_points.real, y=path_points.imag, mode="lines", line_color="black"
        )
        self.broker.publish("dplt_add_trace_request", trace=line)
@@ -0,0 +1,186 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ from ipywidgets import (
5
+ Checkbox,
6
+ Dropdown,
7
+ FloatText,
8
+ IntText,
9
+ )
10
+ from pandas.api.types import is_numeric_dtype
11
+
12
+ from sclab.dataset.processor import Processor
13
+ from sclab.dataset.processor.step import ProcessorStepBase
14
+
15
+ _2PI = 2 * np.pi
16
+
17
+
18
class TransferMetadata(ProcessorStepBase):
    """GUI step that runs ``preprocess.transfer_metadata`` on the dataset's AnnData.

    The controls select:
    - a categorical grouping column (``group_key``);
    - one of its groups (``source_group``);
    - the metadata ``column`` to transfer, which may be categorical, bool or
      numeric.

    For continuous (non-bool numeric) columns, a "periodic" flag can also be
    set, and when it is checked ``vmin``/``vmax`` appear. Presumably these are
    the period bounds; confirm in ``preprocess._transfer_metadata``. After a
    run, the step announces a metadata change for ``transferred_{column}``.
    """

    parent: Processor
    name: str = "transfer_metadata"
    description: str = "Transfer Metadata"

    run_button_description = "Transfer Metadata"

    def __init__(self, parent: Processor) -> None:
        # `IntText` has no `min` trait, so the keyword was silently ignored.
        # The bounded variant actually enforces the lower limit.
        from ipywidgets import BoundedIntText

        variable_controls = dict(
            group_key=Dropdown(
                options=[],
                value=None,
                description="Group Key",
            ),
            source_group=Dropdown(
                options=[],
                value=None,
                description="Source Group",
            ),
            column=Dropdown(
                options=[],
                value=None,
                description="Column",
            ),
            periodic=Checkbox(
                value=False,
                description="Periodic",
            ),
            vmin=FloatText(
                value=0,
                description="Vmin",
                continuous_update=False,
            ),
            vmax=FloatText(
                value=1,
                description="Vmax",
                continuous_update=False,
            ),
            min_neighs=BoundedIntText(
                value=5,
                min=3,
                max=1000,
                description="Min Neighs",
                continuous_update=False,
            ),
            weight_by=Dropdown(
                options=["connectivity", "distance", "constant"],
                value="connectivity",
                description="Weight By",
            ),
        )

        super().__init__(
            parent=parent,
            fixed_params={},
            variable_controls=variable_controls,
        )

        self._update_groupby_options()
        self._update_column_options()
        self._update_numeric_column_controls()

        self.variable_controls["group_key"].observe(
            self._update_source_group_options, "value", "change"
        )
        self.variable_controls["column"].observe(
            self._update_numeric_column_controls, "value", "change"
        )
        self.variable_controls["periodic"].observe(
            self._update_vmin_vmax_visibility, "value", "change"
        )

    @staticmethod
    def _set_options(dropdown: Dropdown, options: list[tuple[str, object]]):
        """Replace ``dropdown.options`` and keep the current value if still offered.

        Assigning ``options`` re-selects the first entry. Re-selecting the
        previous value avoids losing the user's choice every time the
        metadata changes.
        """
        previous = dropdown.value
        dropdown.options = options
        if previous is not None and any(value == previous for _, value in options):
            dropdown.value = previous

    def _is_continuous(self, column) -> bool:
        """Return True if ``column`` is numeric but not boolean.

        ``is_numeric_dtype`` alone also accepts bool columns.
        """
        from pandas.api.types import is_bool_dtype

        metadata = self.parent.dataset._metadata
        if column is None or column not in metadata.columns:
            return False
        series = metadata[column]
        return is_numeric_dtype(series) and not is_bool_dtype(series)

    def _update_groupby_options(self, *args, **kwargs):
        """Offer every categorical metadata column as a grouping key."""
        metadata = self.parent.dataset._metadata.select_dtypes(include=["category"])
        options = [("", None), *((str(c), c) for c in metadata.columns)]
        self._set_options(self.variable_controls["group_key"], options)

    def _update_source_group_options(self, *args, **kwargs):
        """Offer the (sorted, non-missing) groups of the selected grouping key."""
        group_key = self.variable_controls["group_key"].value
        options = [("", None)]
        if group_key is not None:
            groups = (
                self.parent.dataset._metadata[group_key]
                .dropna()
                .sort_values()
                .unique()
            )
            options += [(str(g), g) for g in groups]
        self._set_options(self.variable_controls["source_group"], options)

    def _update_column_options(self, *args, **kwargs):
        """Offer categorical, bool and numeric metadata columns for transfer."""
        metadata = self.parent.dataset._metadata.select_dtypes(
            include=["category", "bool", "number"]
        )
        options = [("", None), *((str(c), c) for c in metadata.columns)]
        self._set_options(self.variable_controls["column"], options)

    def _update_numeric_column_controls(self, *args, **kwargs):
        """Show the periodic controls only for continuous columns."""
        column = self.variable_controls["column"].value
        if self._is_continuous(column):
            self._show_control(self.variable_controls["periodic"])
            self._update_vmin_vmax_visibility()
        else:
            for name in ("periodic", "vmin", "vmax"):
                self._hide_control(self.variable_controls[name])

    def _update_vmin_vmax_visibility(self, *args, **kwargs):
        """Show ``vmin``/``vmax`` only while "periodic" is checked."""
        periodic = self.variable_controls["periodic"].value

        if periodic:
            self._show_control(self.variable_controls["vmin"])
            self._show_control(self.variable_controls["vmax"])
        else:
            self._hide_control(self.variable_controls["vmin"])
            self._hide_control(self.variable_controls["vmax"])

    def _hide_control(self, control):
        """Hide ``control`` and collapse its height."""
        control.layout.visibility = "hidden"
        control.layout.height = "0px"

    def _show_control(self, control):
        """Show ``control`` at the standard row height."""
        control.layout.visibility = "visible"
        control.layout.height = "28px"

    def function(
        self,
        group_key: str,
        source_group: str,
        column: str,
        periodic: bool = False,
        vmin: float = 0,
        vmax: float = 1,
        min_neighs: int = 5,
        weight_by: Literal["connectivity", "distance", "constant"] = "connectivity",
        **kwargs,
    ):
        """Run ``transfer_metadata`` and announce the ``transferred_{column}`` column.

        Parameters are forwarded unchanged, except ``periodic``. It is forced
        to False for non-continuous columns, whose "periodic" checkbox is
        hidden, so that a stale checked value cannot affect the result.
        """
        from ...preprocess._transfer_metadata import transfer_metadata

        periodic = bool(periodic) and self._is_continuous(column)

        self.output.clear_output(wait=True)
        with self.output:
            transfer_metadata(
                self.parent.dataset.adata,
                group_key=group_key,
                source_group=source_group,
                column=column,
                periodic=periodic,
                vmin=vmin,
                vmax=vmax,
                min_neighs=min_neighs,
                weight_by=weight_by,
            )

        new_column = f"transferred_{column}"

        self.broker.publish(
            "dset_metadata_change", self.parent.dataset.metadata, new_column
        )

    def dset_metadata_change_callback(self, *args, **kwargs):
        """Refresh all metadata-dependent controls after a metadata change.

        Groups and column dtypes may have changed even when the selected keys
        stayed the same, so the source groups and periodic controls are
        refreshed explicitly.
        """
        self._update_groupby_options(*args, **kwargs)
        self._update_source_group_options()
        self._update_column_options(*args, **kwargs)
        self._update_numeric_column_controls()