sclab 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1017 @@
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ import svgpathtools as spt
8
+ from ipywidgets.widgets import (
9
+ Box,
10
+ GridBox,
11
+ Layout,
12
+ )
13
+ from numpy.typing import NDArray
14
+ from pandas.api.types import is_any_real_numeric_dtype, is_bool_dtype
15
+ from plotly.graph_objs.scatter import Marker as Marker2D
16
+ from plotly.graph_objs.scatter3d import Marker as Marker3D
17
+ from scipy.interpolate import make_smoothing_spline
18
+
19
+ from ...event import EventBroker, EventClient
20
+ from .._dataset import SCLabDataset
21
+ from ._controls import PlotterControls
22
+ from ._utils import (
23
+ COLOR_DISCRETE_SEQUENCE,
24
+ _make_density_heatmap,
25
+ make_periodic_smoothing_spline,
26
+ rotate_multiple_steps,
27
+ )
28
+
29
+
30
class Plotter(GridBox, EventClient):
    """Grid widget that places a plotly figure next to its controls panel.

    The figure (``g``) is rebuilt from ``dataset`` whenever one of the
    subscribed control or dataset events fires (see ``__init__``).
    """

    # Plotly figure widget created by ``_init_figure``.
    g: go.FigureWidget
    # Controls panel created by ``_init_controls``.
    controls: PlotterControls
    # Dataset currently displayed; assigned by ``load_dataset``.
    dataset: SCLabDataset
    # Scratch state, e.g. per-dimensionality drag mode and drawing flags.
    state: dict[str, Any]

    # Event names associated with this client.
    # NOTE(review): presumably registered with the broker by EventClient - confirm.
    events: list[str] = [
        "dplt_dataset_change",
        "dplt_selected_points_change",
        "dplt_point_click",
        "dplt_layout_shapes_change",
        "dplt_start_drawing_request",
        "dplt_end_drawing_request",
        "dplt_soft_path_computed",
        "dplt_plot_figure_request",
        "dplt_add_trace_request",
        "dplt_add_vline_request",
        "dplt_add_hline_request",
        "dplt_add_data_as_line_trace_request",
    ]
    # Maps one of this client's events to a list of other event names.
    # NOTE(review): exact preemption semantics live in EventClient/EventBroker - confirm.
    preemptions: dict[str, list[str]] = {
        "dplt_selected_points_change": [
            "dspr_selection_values_change",
        ],
        "dplt_plot_figure_request": [
            "ctrl_data_key_change",
            "ctrl_color_change",
            "ctrl_sizeby_change",
        ],
    }
60
+
61
def __init__(
    self,
    dataset: SCLabDataset | None = None,
    broker: EventBroker | None = None,
):
    """Create the figure, the controls panel and the grid holding both.

    Args:
        dataset: Dataset to display. Must be given when ``broker`` is
            omitted, because the broker is then taken from ``dataset.broker``.
        broker: Event broker used for all subscriptions and publications.
    """
    if broker is None:
        assert dataset is not None, "dataset must be provided if broker is None"
        broker = dataset.broker

    EventClient.__init__(self, broker)

    self.modebar = {"add": []}
    self.state = {}

    self._init_figure()
    self._init_controls()
    self.load_dataset(dataset)

    self.graph_box = Box(
        [self.g],
        layout=Layout(
            display="block", width="100%", height="600px", border="0px solid red"
        ),
    )

    grid_layout = Layout(
        width="100%",
        grid_template_columns="auto 350px",
        grid_template_areas=""" "graph controls" """,
        border="0px solid black",
    )
    GridBox.__init__(self, [self.graph_box, self.controls], layout=grid_layout)

    # (event name, handler) pairs, subscribed in this exact order.
    redraw = self.make_new_figure
    subscriptions = (
        ("ctrl_data_key_change", redraw),
        ("ctrl_color_change", redraw),
        ("ctrl_n_dimensions_change", redraw),
        ("ctrl_selected_axes_1_change", redraw),
        ("ctrl_selected_axes_2_change", redraw),
        ("ctrl_selected_axes_3_change", redraw),
        ("ctrl_log_axes_1_change", redraw),
        ("ctrl_log_axes_2_change", redraw),
        ("ctrl_log_axes_3_change", redraw),
        ("ctrl_rotate_steps_change", redraw),
        ("ctrl_refresh_button_click", redraw),
        ("ctrl_histogram_nbins_change", redraw),
        ("dset_selected_rows_change", self.select_points),
        ("dset_total_rows_change", redraw),
        ("dset_metadata_change", redraw),
    )
    for event_name, handler in subscriptions:
        self.broker.subscribe(event_name, handler)
111
+
112
def _init_figure(self):
    """Create the empty figure widget and hook up its relayout observer."""
    initial_layout = dict(
        xaxis_title="",
        yaxis_title="",
        title="",
        template="simple_white",
        height=600,
        modebar=self.modebar,
    )
    self.g = go.FigureWidget(dict(layout=initial_layout))

    def relayout_publisher(event):
        # Translate a front-end relayout message into broker events / state.
        payload = event["new"]
        if not payload:
            return

        relayout_data: dict = payload["relayout_data"]

        if relayout_data.get("selections"):
            picked = self.selected_points
            self.broker.publish(
                "dplt_selected_points_change", None if picked.empty else picked
            )
            # Drop the selection overlay once it has been published.
            with self.g.batch_update():
                self.g.update_traces(selectedpoints=None)
                self.g.plotly_relayout({"selections": None})
            return

        dragmode = relayout_data.get("dragmode")
        if dragmode:
            # Remember the drag mode per dimensionality.
            self.state[f"{self.controls.n_dimensions.value}dragmode"] = dragmode
            return

        shapes = relayout_data.get("shapes")
        if shapes:
            self.broker.publish("dplt_layout_shapes_change", shapes=shapes)

    self.g.observe(relayout_publisher, names="_js2py_relayout")
154
+
155
def _init_controls(self):
    """Instantiate the controls panel on this plotter's broker."""
    shared_broker = self.broker
    self.controls = PlotterControls(shared_broker)
157
+
158
def load_dataset(self, dataset: SCLabDataset | None):
    """Attach ``dataset`` to the plotter and announce the change.

    Args:
        dataset: Dataset to attach; ``None`` creates a new ``SCLabDataset``
            on this plotter's broker.

    Raises:
        TypeError: If ``dataset`` is not an ``SCLabDataset``.
        ValueError: If the dataset's broker id differs from this plotter's.
    """
    candidate = SCLabDataset(self.broker) if dataset is None else dataset

    if not isinstance(candidate, SCLabDataset):
        raise TypeError("dataset must be an instance of SCLabDataset")

    same_broker = candidate.broker.id == self.broker.id
    if not same_broker:
        raise ValueError("dataset broker must be the same as the provided broker")

    self.dataset = candidate
    self.broker.publish("dplt_dataset_change", candidate)
171
+
172
@property
def selected_points(self):
    """Index of the hover-text ids of all currently selected points.

    Contour traces and traces without selected points are ignored.
    """
    collected = set()
    for trace in self.g.data:
        if isinstance(trace, go.Contour):
            continue
        picked = trace.selectedpoints
        if not picked:
            continue
        collected = collected | set(trace.hovertext[list(picked)])
    return pd.Index(collected)
181
+
182
+ def select_points(self, points: pd.Index | None):
183
+ if self.controls.n_dimensions.value != "2D":
184
+ self.make_new_figure()
185
+ return
186
+ else:
187
+ self.update_marker_sizes()
188
+ return
189
+
190
@property
def data_for_plot(self):
    """Data frame to plot, rotated when rotation steps are configured.

    The raw dataset data is returned unless rotation steps are set, at
    least three columns are available and a third axis is selected.
    """
    # TODO: Needs refactoring. Define clear behavior
    steps = self.controls.rotate_steps.value
    if not steps:
        return self.dataset.data

    n_columns = self.dataset.data.shape[1]
    axis_widgets = self.controls.selected_axes.children
    col_x: str = axis_widgets[0].value
    col_y: str = axis_widgets[1].value
    col_z: str = axis_widgets[2].value

    can_rotate = n_columns >= 3 and bool(col_z)
    if not can_rotate:
        return self.dataset.data

    coords = self.dataset.data[[col_x, col_y, col_z]]
    return rotate_multiple_steps(coords, self.controls.rotate_steps.value)
207
+
208
def make_new_figure(
    self,
    metadata: pd.DataFrame | None = None,
    colorby: str | None = None,
    sizeby: str | None = None,
    marker_size_scale: float | None = None,
    new_value: str | None = None,
    *args,
    **figure_kwargs,
):
    """Rebuild the whole figure from the dataset and the current controls.

    Args:
        metadata: Per-row metadata joined onto the plotted data; defaults to
            ``self.dataset.metadata``.
        colorby: Column used for colouring; defaults to the colour control.
        sizeby: Column forwarded to ``update_marker_sizes``.
        marker_size_scale: Scale forwarded to ``update_marker_sizes``.
        new_value: Accepted but not used by this method.
        *args: Accepted but not used by this method.
        **figure_kwargs: Extra keyword arguments passed to the plotly express
            constructor (histogram, scatter or scatter_3d).
    """
    self.g.layout.annotations = []

    ndims = self.controls.n_dimensions.value
    col_x: str = self.controls.selected_axes.children[0].value
    col_y: str = self.controls.selected_axes.children[1].value
    col_z: str = self.controls.selected_axes.children[2].value

    # "rank" on x needs a y column to rank by.
    invalid_axes = col_x == "rank" and not col_y

    if self.dataset.data.empty or invalid_axes:
        layout = dict(
            xaxis_title="", yaxis_title="", title="No Data", template="simple_white"
        )
        self.g.update(dict(layout=layout, data=[]), overwrite=True)
        return

    data = self.data_for_plot
    if metadata is None:
        metadata = self.dataset.metadata

    # Metadata columns that clash with data columns are dropped from the join.
    df = data.join(metadata.loc[:, ~metadata.columns.isin(data.columns)])
    if col_x == "rank":
        df = df.sort_values(col_y, ascending=False, na_position="last")
        df["rank"] = np.arange(df.shape[0]) + 1

    log_x: bool = self.controls.log_axes.children[0].value
    log_y: bool = self.controls.log_axes.children[1].value
    log_z: bool = self.controls.log_axes.children[2].value

    # Log axes are realised as explicit "log <col>" columns; non-positive
    # rows are dropped. ``.copy()`` keeps the column assignment from
    # targeting a filtered view of the previous frame.
    if log_x and ndims in ["1D", "2D", "3D"]:
        df = df[df[col_x] > 0].copy()
        df["log " + col_x] = np.log10(df[col_x])
        col_x = "log " + col_x

    if log_y and ndims in ["2D", "3D"]:
        df = df[df[col_y] > 0].copy()
        df["log " + col_y] = np.log10(df[col_y])
        col_y = "log " + col_y

    if log_z and ndims == "3D":
        df = df[df[col_z] > 0].copy()
        df["log " + col_z] = np.log10(df[col_z])
        col_z = "log " + col_z

    # The values are already transformed, so plotly must not log them again.
    log_x = log_y = log_z = False

    # Axis extents are taken before any selection filtering below.
    if ndims in ["1D", "2D", "3D"]:
        x = df[col_x]
        dx = x.max() - x.min()

    if ndims in ["2D", "3D"]:
        y = df[col_y]
        dy = y.max() - y.min()

    if ndims == "3D":
        z = df[col_z]
        dz = z.max() - z.min()

    selected = df["is_selected"]
    selection_is_active = not selected.isna().all()
    if selection_is_active and (ndims == "1D" or ndims == "3D"):
        df = df.loc[selected]

    if colorby is None:
        colorby = self.controls.color.value

    if colorby:
        ascending = is_any_real_numeric_dtype(df[colorby])
        df = df.sort_values(colorby, ascending=ascending, na_position="first")

        # ``series`` keeps the original dtype for the checks further down.
        series: pd.Series = df[colorby]
        is_discrete = isinstance(
            series.dtype, pd.CategoricalDtype
        ) or is_bool_dtype(series)
        if is_discrete:
            df[colorby] = series.astype(str).replace("nan", "NA")
        elif ndims == "1D":
            # break into 10 evenly distributed bins
            df[colorby] = pd.cut(series, 10).astype(str).replace("nan", "NA")

    if ndims == "1D":
        fig = px.histogram(
            df,
            x=col_x,
            log_x=log_x,
            color=colorby,
            color_discrete_sequence=COLOR_DISCRETE_SEQUENCE,
            template="simple_white",
            nbins=self.controls.histogram_nbins.value,
            **figure_kwargs,
        )
    elif ndims == "2D":
        fig = px.scatter(
            df,
            x=col_x,
            y=col_y,
            log_x=log_x,
            log_y=log_y,
            color=colorby,
            color_discrete_sequence=COLOR_DISCRETE_SEQUENCE,
            hover_name=df.index,
            template="simple_white",
            render_mode="webgl",
            **figure_kwargs,
        )
    elif ndims == "3D":
        fig = px.scatter_3d(
            df,
            x=col_x,
            y=col_y,
            z=col_z,
            log_x=log_x,
            log_y=log_y,
            log_z=log_z,
            color=colorby,
            color_discrete_sequence=COLOR_DISCRETE_SEQUENCE,
            hover_name=df.index,
            template="simple_white",
            **figure_kwargs,
        )
    else:
        layout = dict(
            xaxis_title="",
            yaxis_title="",
            title="No Data",
            template="simple_white",
            height=self.controls.plot_height.value,
            modebar=self.modebar,
            dragmode=False,
        )
        self.g.update(dict(layout=layout, data=[]), overwrite=True)
        return

    fig.update_layout(legend_title_text="")
    fig.update_layout(coloraxis_colorbar_title_text="")

    fig.update_traces(marker_color="lightgray", selector={"name": "NA"})
    if colorby and is_bool_dtype(series):
        fig.update_traces(marker_color="lightgray", selector={"name": "False"})
        fig.update_traces(
            marker_color=COLOR_DISCRETE_SEQUENCE[0], selector={"name": "True"}
        )
        fig.data = fig.data[::-1]

    if colorby and isinstance(series.dtype, pd.CategoricalDtype):
        palette_size = len(COLOR_DISCRETE_SEQUENCE)
        for i, cat in enumerate(series.cat.categories):
            # Trace names come from the stringified column, so match on str.
            fig.update_traces(
                marker_color=COLOR_DISCRETE_SEQUENCE[i % palette_size],
                selector={"name": str(cat)},
            )

    if ndims == "2D" or ndims == "3D":
        fig.update_layout(legend_traceorder="reversed")
        if col_x == "rank" or not self.controls.enable_hover_info.value:
            fig.update_traces(hoverinfo="skip", hovertemplate=None)
    else:
        fig.data = fig.data[::-1]
        fig.update_layout(legend_traceorder="normal")

    if ndims == "2D" and self.controls.show_density.value:
        # Density contours are computed from the points in index order.
        # Rows are passed as tuples of tuples.
        points = tuple(tuple(row) for row in df.sort_index()[[col_x, col_y]].values)
        density_trace = _make_density_heatmap(
            points,
            self.controls.density_bandwidth_factor.value,
            self.controls.density_grid_resolution.value,
            self.controls.density_line_smoothing.value,
            self.controls.density_contours.value,
            "rgba(255, 255, 255, 0)",
        )
        fig.add_trace(density_trace)

    fig.update_layout(height=self.controls.plot_height.value)

    if self.controls.aspect_equal.value:
        if ndims == "3D":
            fig.update_layout(scene_aspectmode="data")
        elif ndims == "2D":
            fig.update_xaxes(scaleanchor="y", scaleratio=1)
            fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Fix the axis ranges with a 10 % margin on either side.
    if ndims == "1D":
        fig.update_layout(xaxis_range=[x.min() - dx * 0.1, x.max() + dx * 0.1])
    elif ndims == "2D":
        fig.update_layout(xaxis_range=[x.min() - dx * 0.1, x.max() + dx * 0.1])
        fig.update_layout(yaxis_range=[y.min() - dy * 0.1, y.max() + dy * 0.1])
    elif ndims == "3D":
        fig.update_layout(
            scene_xaxis_range=[x.min() - dx * 0.1, x.max() + dx * 0.1]
        )
        fig.update_layout(
            scene_yaxis_range=[y.min() - dy * 0.1, y.max() + dy * 0.1]
        )
        fig.update_layout(
            scene_zaxis_range=[z.min() - dz * 0.1, z.max() + dz * 0.1]
        )

    # Reuse the remembered drag mode, else fall back to a per-mode default.
    dragmode = self.get_dragmode()
    if not dragmode:
        if ndims == "3D":
            dragmode = "turntable"
        elif ndims == "2D":
            dragmode = "lasso"
        else:
            dragmode = False

    fig.update_layout(title=self.dataset._selected_data_key, modebar=self.modebar)
    fig.update_layout(template_layout_shapedefaults_fillcolor="rgba(0, 0, 0, 0)")
    fig.update_layout(legend={"itemsizing": "constant"})
    fig.update_layout(showlegend=bool(colorby))

    with self.g.batch_update():
        self.g.update(fig.to_dict(), overwrite=True)
        self.update_marker_sizes(
            metadata=metadata,
            colorby=colorby,
            sizeby=sizeby,
            marker_size_scale=marker_size_scale,
        )
        self.set_dragmode(dragmode)
453
+
454
+ def set_dragmode(self, dragmode: str | None = None):
455
+ ndims = self.controls.n_dimensions.value
456
+ if dragmode is None:
457
+ dragmode = self.state.get(f"{ndims}dragmode", False)
458
+
459
+ self.g.plotly_relayout({"dragmode": dragmode})
460
+ self.state[f"{ndims}dragmode"] = dragmode
461
+
462
def get_dragmode(self) -> str:
    """Return the drag mode stored for the current mode, or ``False``."""
    state_key = f"{self.controls.n_dimensions.value}dragmode"
    try:
        return self.state[state_key]
    except KeyError:
        return False
465
+
466
def set_shapes(self, shapes: list[dict]):
    """Replace the figure's layout shapes and publish the change."""
    figure = self.g
    figure.layout.shapes = shapes
    # The relayout message is also sent explicitly.
    figure._send_relayout_msg({"shapes": shapes})
    self.broker.publish("dplt_layout_shapes_change", shapes=shapes)
470
+
471
def clear_shapes(self):
    """Remove every layout shape by setting an empty shape list."""
    no_shapes: list[dict] = []
    self.set_shapes(no_shapes)
473
+
474
+ def update_marker_sizes(
475
+ self,
476
+ marker_size_scale: float | None = None,
477
+ sizeby: str | None = None,
478
+ colorby: str | None = None,
479
+ metadata: pd.DataFrame | None = None,
480
+ ):
481
+ if colorby is None:
482
+ colorby = self.controls.color.value
483
+
484
+ if marker_size_scale is None:
485
+ marker_size_scale = self.controls.marker_size.value
486
+
487
+ if sizeby is None:
488
+ sizeby = self.controls.sizeby.value
489
+
490
+ if metadata is None:
491
+ df = self.dataset.metadata
492
+ else:
493
+ df = metadata
494
+
495
+ sizeby_series = pd.Series(1.0, index=df.index)
496
+ marker_sizeref = 1.0 / marker_size_scale**2
497
+
498
+ # is_selected may be a boolean column or a column of NaNs
499
+ if active_selection := (not (selected := df["is_selected"]).isna().all()):
500
+ # if it is a boolean column, a selection has been defined (possible all False)
501
+ sizeby_series.loc[selected] = 3.0
502
+ sizeby_series.loc[~selected] = 0.5
503
+
504
+ elif sizeby is not None:
505
+ sizeby_series = df[sizeby].astype(float)
506
+ sizeby_series.loc[sizeby_series < 0] = 0.0
507
+ sizeby_series = sizeby_series.fillna(0.0)
508
+ marker_sizeref = sizeby_series.max() / marker_size_scale**2
509
+
510
+ trace: go.Scatter | go.Scattergl | go.Scatter3d
511
+ for trace in self.g.data:
512
+ marker_ids = trace.hovertext
513
+ if not isinstance(marker_ids, np.ndarray | list):
514
+ continue
515
+
516
+ if df.index.intersection(marker_ids).empty:
517
+ continue
518
+
519
+ trace.hovertemplate = self.get_hovertemplate(
520
+ info={colorby: trace.name}, show_size=not active_selection
521
+ )
522
+
523
+ marker: Marker2D | Marker3D = trace.marker
524
+ marker.sizemode = "area"
525
+ marker.sizeref = marker_sizeref
526
+ marker.size = sizeby_series.loc[marker_ids].values
527
+ marker.line.width = 0.0
528
+
529
+ def get_hovertemplate(self, info: dict = {}, show_size: bool = True) -> str:
530
+ ndims = self.controls.n_dimensions.value
531
+ x_axis = self.controls.selected_axes.children[0].value
532
+ y_axis = self.controls.selected_axes.children[1].value
533
+ z_axis = self.controls.selected_axes.children[2].value
534
+ marker_color: str | None = self.controls.color.value
535
+ marker_size: str | None = self.controls.sizeby.value
536
+
537
+ if marker_color is not None:
538
+ series = self.dataset.metadata[marker_color]
539
+ marker_color_is_cat = isinstance(series.dtype, pd.CategoricalDtype)
540
+ else:
541
+ marker_color_is_cat = False
542
+
543
+ is_histogram = ndims == "1D"
544
+ is_scatter = ndims in ["2D", "3D"]
545
+ is_3d_scatter = ndims == "3D"
546
+ is_na = ndims == "NA"
547
+
548
+ hovertemplate = ""
549
+ if is_scatter | is_na:
550
+ hovertemplate = "<b>%{hovertext}</b><br>"
551
+
552
+ if is_scatter and marker_color and not marker_color_is_cat:
553
+ info.pop(marker_color, None)
554
+ hovertemplate += f"{marker_color} = %{{marker.color}}<br>"
555
+
556
+ if is_scatter and marker_color and marker_color_is_cat:
557
+ trace_name = info.pop(marker_color, None)
558
+ if trace_name is not None:
559
+ hovertemplate += f"{marker_color} = {trace_name}<br>"
560
+
561
+ if is_scatter and marker_size and show_size and marker_size != marker_color:
562
+ hovertemplate += f"{marker_size} = %{{marker.size}}<br>"
563
+
564
+ if is_histogram:
565
+ hovertemplate += "count = %{y}<br>"
566
+
567
+ hovertemplate = hovertemplate[:-4]
568
+ hovertemplate += "<extra>"
569
+
570
+ if not is_na:
571
+ hovertemplate += f"{x_axis} = %{{x}}<br>"
572
+ if is_scatter:
573
+ hovertemplate += f"{y_axis} = %{{y}}<br>"
574
+ if is_3d_scatter:
575
+ hovertemplate += f"{z_axis} = %{{z}}<br>"
576
+
577
+ if info:
578
+ hovertemplate += "<br>"
579
+
580
+ for key, val in info.items():
581
+ if key or val:
582
+ hovertemplate += f"{key} = {val}<br>"
583
+ else:
584
+ hovertemplate = hovertemplate[:-4]
585
+
586
+ hovertemplate += "</extra>"
587
+
588
+ return hovertemplate
589
+
590
def dplt_point_click_callback(self, row_name, device_state, **kwargs):
    """Toggle the marker size of the clicked point between 1x and 3x.

    Only the first trace containing ``row_name`` is updated. Contour traces
    and traces without per-point ids or per-point sizes are skipped.

    Args:
        row_name: Id of the clicked row, matched against trace hover text.
        device_state: Not used by this method.
        **kwargs: Not used by this method.
    """
    if self.controls.n_dimensions.value == "1D":
        return

    for trace in self.g.data:
        if isinstance(trace, go.Contour):
            continue

        point_ids = trace.hovertext
        sizes = trace.marker.size
        # Traces added as plain lines have neither per-point ids nor sizes.
        if point_ids is None or isinstance(point_ids, str):
            continue
        if sizes is None or np.isscalar(sizes):
            continue

        matches = np.flatnonzero(np.asarray(point_ids) == row_name)
        if matches.size == 0:
            continue
        idx = int(matches[0])

        # Copy only once the point is found.
        marker_size = np.array(sizes, copy=True)
        default_size = self.controls.marker_size.value

        if marker_size[idx] == default_size:
            marker_size[idx] = default_size * 3
        else:
            marker_size[idx] = default_size

        trace.marker.size = marker_size
        break
628
+
629
def ctrl_show_density_change_callback(self, new_value):
    """Show or hide the density contour trace.

    When density is switched on and no contour trace exists yet, the figure
    is rebuilt. Switching it off without a contour trace does nothing.
    """
    contour = next(
        (trace for trace in self.g.data if isinstance(trace, go.Contour)), None
    )

    if contour is not None:
        contour.visible = True if new_value else False
    elif new_value:
        self.make_new_figure()
642
+
643
def ctrl_density_line_smoothing_change_callback(self, new_value):
    """Apply a new line smoothing to contour traces while density is shown."""
    if not self.controls.show_density.value:
        return

    contour_selector = dict(type="contour")
    self.g.update_traces(selector=contour_selector, line_smoothing=new_value)
648
+
649
def ctrl_density_contours_change_callback(self, new_value):
    """Re-space the density contour levels into ``new_value`` steps.

    Does nothing when density is hidden or when the figure holds no contour
    trace (previously a bare ``next`` raised ``StopIteration`` then).

    Args:
        new_value: Number of contour steps between the min and max of ``z``.
    """
    if not self.controls.show_density.value:
        return

    selector = dict(type="contour")
    trace: go.Contour | None = next(self.g.select_traces(selector=selector), None)
    if trace is None:
        return

    z_values: NDArray = np.asarray(trace.z)
    # Both ends are offset by the same small amount.
    start = z_values.min() + 0.0001
    end = z_values.max() + 0.0001
    size = (end - start) / new_value
    contours = dict(start=start, end=end, size=size)

    self.g.update_traces(selector=selector, contours=contours)
662
+
663
def ctrl_density_grid_resolution_change_callback(self, new_value):
    """Rebuild the figure on a grid resolution change while density is shown."""
    density_shown = self.controls.show_density.value
    if not density_shown:
        return
    self.make_new_figure()
666
+
667
def ctrl_density_bandwidth_factor_change_callback(self, new_value):
    """Rebuild the figure on a bandwidth factor change while density is shown."""
    density_shown = self.controls.show_density.value
    if not density_shown:
        return
    self.make_new_figure()
670
+
671
def ctrl_marker_size_change_callback(self, new_value):
    """Apply a new marker size scale; does nothing in 1D mode."""
    if self.controls.n_dimensions.value != "1D":
        with self.g.batch_update():
            self.update_marker_sizes(marker_size_scale=new_value)
678
+
679
def ctrl_sizeby_change_callback(self, new_value):
    """Resize markers by a new metadata column; does nothing in 1D mode."""
    if self.controls.n_dimensions.value != "1D":
        with self.g.batch_update():
            self.update_marker_sizes(sizeby=new_value)
686
+
687
def ctrl_marker_opacity_change_callback(self, new_value):
    """Set the opacity of all scatter-type traces; does nothing in 1D mode."""
    if self.controls.n_dimensions.value == "1D":
        return

    scatter_types = ("scatter", "scattergl", "scatter3d")
    with self.g.batch_update():
        for trace_type in scatter_types:
            self.g.update_traces(selector=dict(type=trace_type), opacity=new_value)
696
+
697
def ctrl_aspect_equal_change_callback(self, new_value):
    """Turn equal-aspect axes on or off for the 2D and 3D modes."""
    ndims = self.controls.n_dimensions.value

    if ndims == "3D":
        self.g.update_layout(scene_aspectmode="data" if new_value else "auto")
    elif ndims == "2D":
        # Anchoring each axis to the other with ratio 1 gives equal aspect.
        x_anchor, y_anchor, ratio = ("y", "x", 1) if new_value else (None, None, None)
        self.g.update_xaxes(scaleanchor=x_anchor, scaleratio=ratio)
        self.g.update_yaxes(scaleanchor=y_anchor, scaleratio=ratio)
711
+
712
def ctrl_enable_hover_info_change_callback(self, new_value):
    """Enable hover info with fresh templates, or disable it on all traces."""
    if not new_value:
        self.g.update_traces(hoverinfo="skip", hovertemplate=None)
        return

    color_column = self.controls.color.value
    for trace in self.g.data:
        trace.hoverinfo = "all"
        trace.hovertemplate = self.get_hovertemplate(info={color_column: trace.name})
720
+
721
def ctrl_plot_width_change_callback(self, new_value):
    """Set the graph box width in pixels; non-positive values mean "auto"."""
    width = f"{new_value}px" if new_value > 0 else "auto"
    self.graph_box.layout.width = width
726
+
727
def ctrl_plot_height_change_callback(self, new_value):
    """Resize both the graph box and the figure to ``new_value`` pixels."""
    box_layout = self.graph_box.layout
    box_layout.height = f"{new_value}px"
    self.g.update_layout(height=new_value)
730
+
731
def ctrl_save_button_click_callback(self):
    """Write the current figure to an HTML file named after the plot settings.

    The name is assembled from the dataset name, the selected data key, the
    plotted axes and, when set, the colour column. Nothing is written when
    no data key is selected.
    """
    dataset_name = self.dataset.name
    data_key = self.dataset._selected_data_key
    if not data_key:
        return

    # Name segments, joined with "___" at the end.
    segments = [f"{dataset_name}___saved_figure___data_{data_key}"]

    ndims = self.controls.n_dimensions.value
    axis_widgets = self.controls.selected_axes.children
    if ndims == "1D":
        segments.append(f"{axis_widgets[0].value}-histogram")
    elif ndims == "2D":
        segments.append("-vs-".join(f"{w.value}" for w in axis_widgets[:2]))
    elif ndims == "3D":
        segments.append("-vs-".join(f"{w.value}" for w in axis_widgets[:3]))

    if self.controls.color.value:
        segments.append(f"colorBy_{self.controls.color.value}")

    self.g.write_html("___".join(segments) + ".html")
758
+
759
def dspr_clear_selection_click_callback(self):
    """Clear selected points and the selection overlay in 2D mode."""
    figure = self.g
    with figure.batch_update():
        in_2d = self.controls.n_dimensions.value == "2D"
        if in_2d:
            figure.update_traces(selectedpoints=None)
            figure.plotly_relayout({"selections": None})
764
+
765
def dplt_start_drawing_request_callback(self):
    """Switch to 2D, lock the data-affecting controls and start path drawing.

    Does nothing when "2D" is not among the dimension options. The drag mode
    in effect after switching to 2D is stored so it can be restored later.
    """
    if "2D" not in self.controls.n_dimensions.options:
        return

    widgets = self.controls.dict
    widgets["n_dimensions"].value = "2D"
    self.state["drawing"] = True
    self.state["previous_dragmode"] = self.get_dragmode()

    locked_controls = (
        "data_key",
        "color",
        "sizeby",
        "n_dimensions",
        "selected_axes_1",
        "selected_axes_2",
        "selected_axes_3",
        "log_axes_1",
        "log_axes_2",
        "log_axes_3",
    )
    for name in locked_controls:
        widgets[name].disabled = True

    self.set_dragmode("drawopenpath")
786
+
787
def dplt_end_drawing_request_callback(self):
    """Leave drawing mode: clear shapes, unlock controls, restore drag mode.

    Does nothing unless a drawing session is active.
    """
    currently_drawing = self.state.get("drawing", False)
    if not currently_drawing:
        return

    self.state["drawing"] = False
    self.clear_shapes()

    unlocked_controls = (
        "data_key",
        "color",
        "sizeby",
        "n_dimensions",
        "selected_axes_1",
        "selected_axes_2",
        "selected_axes_3",
        "log_axes_1",
        "log_axes_2",
        "log_axes_3",
    )
    widgets = self.controls.dict
    for name in unlocked_controls:
        widgets[name].disabled = False

    self.set_dragmode(self.state["previous_dragmode"])
807
+
808
def dplt_layout_shapes_change_callback(self, shapes: list[dict]):
    """Turn the drawn path shapes into a smoothed path and publish it.

    All path shapes except the previously generated "smooth path" are joined
    into one path, which is snapped closed when its ends nearly meet. With
    at least five vertices it is smoothed with splines. The figure shapes
    are then replaced by the hidden drawn path and the visible smooth path,
    and ``dplt_soft_path_computed`` is published with the sampled path and
    the data points of the two selected axes.

    Args:
        shapes: Layout shape dicts. Shapes without a ``"path"`` entry (for
            example line shapes) are ignored.
    """
    # ``.get`` keeps shapes without a name/path from raising KeyError.
    drawn = [
        s for s in shapes if s.get("name") != "smooth path" and s.get("path")
    ]
    if not drawn:
        return

    path = spt.Path()
    for shape in drawn:
        segment = spt.parse_path(shape["path"])
        if len(path) > 0 and len(segment) > 0:
            # Bridge consecutive strokes with a straight line.
            path.append(spt.Line(path[-1].end, segment[0].start))
        path.extend(segment)

    if len(path) == 0:
        return

    # Snap the path closed when the end gap is under 1 % of its length.
    pi, pf = path[0].start, path[-1].end
    if np.abs(pf - pi) / path.length() < 0.01:
        path[-1].end = pi

    n = len(path) + 1
    if n < 3:
        return

    # Vertices as complex numbers (real = x, imag = y).
    p = np.array([path[0].start] + [line.end for line in path])
    path_is_closed = p[-1] == p[0]

    smooth = n >= 5
    if smooth:
        # Parameterise the vertices by normalised cumulative segment length.
        ends = np.array([[line.start, line.end] for line in path])
        t = np.abs(np.diff(ends))
        t = np.insert(np.cumsum(t / t.sum()), 0, 0)
        if path_is_closed:
            px_bspl = make_periodic_smoothing_spline(
                t[:-1], p.real[:-1], t_range=(0, 1), lam=1 / 5e3 / n
            )
            py_bspl = make_periodic_smoothing_spline(
                t[:-1], p.imag[:-1], t_range=(0, 1), lam=1 / 5e3 / n
            )
            t = np.linspace(0, 1, 10 * n) % 1
        else:
            px_bspl = make_smoothing_spline(t, p.real, lam=1 / 1e3 / n)
            py_bspl = make_smoothing_spline(t, p.imag, lam=1 / 1e3 / n)
            t = np.linspace(0, 1, 10 * n)

        points = px_bspl(t) + 1j * py_bspl(t)
        spath = spt.Path()
        spath.extend(
            [spt.Line(start, end) for start, end in zip(points, points[1:])]
        )
    else:
        spath = path

    def path_shape(name: str, visible: bool, shape_path) -> dict:
        # Shape dict for one path; only name, visibility and path differ.
        return {
            "editable": False,
            "visible": visible,
            "name": name,
            "showlegend": False,
            "legend": "legend",
            "legendgroup": "",
            "legendgrouptitle": {
                "text": "",
                "font": {"weight": "normal", "style": "normal", "variant": "normal"},
            },
            "legendrank": 1000,
            "label": {"text": "", "texttemplate": ""},
            "xref": "x",
            "yref": "y",
            "layer": "above",
            "opacity": 1,
            "line": {"color": "#444", "width": 4, "dash": "solid"},
            "type": "path",
            "path": shape_path.d(use_closed_attrib=path_is_closed).replace(" ", ""),
        }

    s1 = path_shape("drawn path", False, path)
    s2 = path_shape("smooth path", True, spath)
    self.g._send_relayout_msg({"shapes": (s1, s2)})

    # Dense sampling of the path for downstream consumers.
    times = np.linspace(0, 1, 100 * n)
    if smooth:
        points = px_bspl(times) + 1j * py_bspl(times)
    else:
        points = np.array([path.point(t) for t in times])

    col_x: str = self.controls.selected_axes.children[0].value
    col_y: str = self.controls.selected_axes.children[1].value
    # Evaluate the (possibly rotated) data once for both axes.
    plotted = self.data_for_plot
    data_points = plotted[col_x].values + 1j * plotted[col_y].values

    self.broker.publish(
        "dplt_soft_path_computed",
        time_points=times,
        data_points=data_points,
        path_points=points,
        path_is_closed=path_is_closed,
    )
926
+
927
def dplt_plot_figure_request_callback(
    self,
    figure: go.Figure | None = None,
    metadata: pd.DataFrame | None = None,
    colorby: str | None = None,
    sizeby: str | None = None,
    **figure_kwargs,
):
    """Display a requested figure, or rebuild the default one.

    Without ``figure`` the call is forwarded to ``make_new_figure``.
    Otherwise the colour and size controls are set to the requested columns
    (``None`` when they are not valid options), the figure takes over the
    current template and replaces the displayed one, and dragging is
    disabled.
    """
    if figure is None:
        self.make_new_figure(
            metadata=metadata,
            colorby=colorby,
            sizeby=sizeby,
            **figure_kwargs,
        )
        return

    color_control = self.controls.color
    color_control.value = colorby if colorby in color_control.options else None

    size_control = self.controls.sizeby
    size_control.value = sizeby if sizeby in size_control.options else None

    figure.layout.template = self.g.layout.template

    with self.g.batch_update():
        self.g.update(figure.to_dict(), overwrite=True)
        self.update_marker_sizes(colorby=colorby, sizeby=sizeby)
        self.g.plotly_relayout({"dragmode": False})
960
+
961
def dplt_add_trace_request_callback(self, trace: go.Scatter | go.Scattergl):
    """Append ``trace`` to the figure inside a batched update."""
    figure = self.g
    with figure.batch_update():
        figure.add_trace(trace)
964
+
965
+ def dplt_add_vline_request_callback(self, vlines: float | list[float], **kwargs):
966
+ if not isinstance(vlines, list):
967
+ vlines = [vlines]
968
+
969
+ with self.g.batch_update():
970
+ for vline in vlines:
971
+ self.g.add_vline(x=vline, **kwargs)
972
+
973
+ def dplt_add_hline_request_callback(self, hlines: float | list[float], **kwargs):
974
+ if not isinstance(hlines, list):
975
+ hlines = [hlines]
976
+
977
+ with self.g.batch_update():
978
+ for hline in hlines:
979
+ self.g.add_hline(y=hline, **kwargs)
980
+
981
def dplt_add_data_as_line_trace_request_callback(
    self, data_key: str, data: pd.DataFrame, **kvargs
):
    """Draw ``data`` as a line trace when it matches the displayed data key.

    In 1D the line uses ``<x column>_height`` as its y values and is skipped
    when that column is missing. In 2D and 3D the selected axis columns are
    used. Extra keyword arguments are passed to the trace constructor.
    """
    if data_key != self.dataset._selected_data_key:
        return

    ndims = self.controls.n_dimensions.value
    axis_widgets = self.controls.selected_axes.children
    col_x: str = axis_widgets[0].value
    col_y: str = axis_widgets[1].value
    col_z: str = axis_widgets[2].value

    # The same rotation setting as the controls is applied to the line data.
    data = rotate_multiple_steps(data, self.controls.rotate_steps.value)

    if ndims == "1D":
        height_col = col_x + "_height"
        if height_col not in data.columns:
            return
        trace = go.Scattergl(
            x=data[col_x].values, y=data[height_col].values, mode="lines", **kvargs
        )
    elif ndims == "2D":
        trace = go.Scattergl(
            x=data[col_x].values, y=data[col_y].values, mode="lines", **kvargs
        )
    elif ndims == "3D":
        trace = go.Scatter3d(
            x=data[col_x].values,
            y=data[col_y].values,
            z=data[col_z].values,
            mode="lines",
            **kvargs,
        )
    else:
        return

    with self.g.batch_update():
        self.g.add_trace(trace)