scatter3d-anywidget 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scatter3d/scatter3d.py ADDED
@@ -0,0 +1,1011 @@
1
+ import os
2
+ from pathlib import Path
3
+ from itertools import cycle, count
4
+ from enum import Enum
5
+ from collections import OrderedDict
6
+ from typing import Any, Callable, Sequence
7
+ import weakref
8
+ import base64
9
+
10
+ import anywidget
11
+ import traitlets
12
+ import numpy
13
+ import pandas
14
+ import narwhals
15
+
16
+
17
# --- Front-end bundle locations ---------------------------------------------
PACKAGE_DIR = Path(__file__).parent
JAVASCRIPT_DIR = PACKAGE_DIR.joinpath("static")
# Built ESM bundle shipped with the package (returned by _esm_source by default).
PROD_ESM = JAVASCRIPT_DIR.joinpath("scatter3d.js")
# Default dev-server entry point, used when ANY_SCATTER3D_DEV is set and
# ANY_SCATTER3D_DEV_URL is not.
DEF_DEV_ESM = "http://127.0.0.1:5173/src/index.ts"

# --- Binary transport dtypes --------------------------------------------------
FLOAT_TYPE = "<f4"
FLOAT_TYPE_TS = "float32"
CATEGORY_CODES_DTYPE = "<u4"  # uint32 little-endian

# --- Missing-value presentation -------------------------------------------------
MISSING_COLOR = (0.6, 0.6, 0.6)
MISSING_CATEGORY_VALUE = "Unassigned"

# --- Rendering defaults -----------------------------------------------------------
DARK_GREY = "#111111"
WHITE = "#ffffff"
DEFAULT_POINT_SIZE = 0.15
DEFAULT_AXIS_LABEL_SIZE = 0.2

# Default categorical palette: 20 RGB triples with components in [0, 1],
# arranged as (darker, lighter) pairs. The values appear to match matplotlib's
# "tab20" qualitative colormap.
TAB20_COLORS_RGB = [
    # blue
    (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
    (0.6823529411764706, 0.7803921568627451, 0.9098039215686274),
    # orange
    (1.0, 0.4980392156862745, 0.054901960784313725),
    (1.0, 0.7333333333333333, 0.47058823529411764),
    # green
    (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
    (0.596078431372549, 0.8745098039215686, 0.5411764705882353),
    # red
    (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
    (1.0, 0.596078431372549, 0.5882352941176471),
    # purple
    (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
    (0.7725490196078432, 0.6901960784313725, 0.8352941176470589),
    # brown
    (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
    (0.7686274509803922, 0.611764705882353, 0.5803921568627451),
    # pink
    (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
    (0.9686274509803922, 0.7137254901960784, 0.8235294117647058),
    # grey
    (0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
    (0.7803921568627451, 0.7803921568627451, 0.7803921568627451),
    # olive
    (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
    (0.8588235294117647, 0.8588235294117647, 0.5529411764705883),
    # cyan
    (0.09019607843137255, 0.7450980392156863, 0.8117647058823529),
    (0.6196078431372549, 0.8549019607843137, 0.8980392156862745),
]
54
+
55
+
56
class LabelListErrorResponse(Enum):
    """Policy used by ``Category.set_label_list`` when existing labels are dropped.

    ``ERROR`` refuses the new label list and raises ``ValueError``;
    ``SET_MISSING`` accepts it and recodes values whose label was dropped
    to the missing code 0.
    """

    ERROR = "error"  # reject label lists that drop existing labels
    SET_MISSING = "missing"  # dropped labels become missing / unassigned
59
+
60
+
61
def _is_valid_color(color):
    """Validate an RGB color, raising ``ValueError`` when it is malformed.

    A valid color is a ``tuple`` of exactly three numbers, each within
    ``[0, 1]``. The function returns ``None``; it is used only for its check.

    Raises:
        ValueError: if *color* is not a 3-tuple, or a component is
            non-numeric, NaN, or outside ``[0, 1]``.
    """
    if not isinstance(color, tuple) or len(color) != 3:
        raise ValueError(f"Invalid color, should be tuples with three floats {color}")
    for value in color:
        try:
            # A positive range test rejects NaN, which fails every comparison
            # and would slip through a "value < 0 or value > 1" check.
            in_range = 0 <= value <= 1
        except TypeError:
            # Non-numeric component (e.g. a string): report it as an invalid
            # color rather than leaking a TypeError.
            in_range = False
        if not in_range:
            raise ValueError(
                f"Invalid color, should be coded as floats from 0 to 1 {color}"
            )
71
+
72
+
73
CategoryCallback = Callable[["Category", str], None]


class Category:
    """A categorical variable stored as compact integer codes.

    Labels are numbered ``1..K`` in label-list order; code ``0`` marks a
    missing / unassigned value. The codes live in a numpy array
    (``coded_values``) alongside a label -> RGB color palette.

    Callbacks registered with :meth:`subscribe` are called as
    ``cb(category, event)`` after a mutation, with ``event`` one of
    ``"label_list"``, ``"palette"`` or ``"coded_values"``. Callbacks are held
    through weak references, so subscribing never keeps a subscriber alive.
    """

    # Codes are encoded as uint16 (see _encode_values) and 0 is reserved for
    # "missing", so at most this many labels can be represented.
    _MAX_LABELS = 65535

    def __init__(
        self,
        values: narwhals.typing.IntoSeriesT,
        label_list=None,
        color_palette: dict[Any, tuple[float, float, float]] | None = None,
        missing_color: tuple[float, float, float] = MISSING_COLOR,
    ):
        """Encode *values* and build the color palette.

        Args:
            values: a native series supported by ``narwhals.from_native``.
                Values not found in the label coding are encoded as 0.
            label_list: optional explicit label order; it must contain every
                non-null value. When omitted, the sorted unique non-null
                values are used.
            color_palette: optional label -> RGB mapping; when given (and
                non-empty) it must cover every label.
            missing_color: RGB color used for missing values.

        Raises:
            RuntimeError: if *label_list* misses a value present in *values*.
            ValueError: if the label list has duplicates or too many labels,
                or a color is invalid.
            KeyError: if *color_palette* lacks a color for some label.
        """
        self._cb_id_gen = count(1)
        self._callbacks: dict[int, weakref.ReferenceType] = {}

        # Keep the native dtype so `values` can round-trip to the same type.
        self._native_values_dtype = values.dtype
        values = narwhals.from_native(values, series_only=True)
        self._narwhals_values_dtype = values.dtype
        self._name = values.name
        self._values_implementation = values.implementation

        label_list = self._initialize_label_list(values, label_list)

        self._label_coding = None
        self._label_coding = self._create_label_coding(label_list)

        self._encode_values(values)

        self.create_color_palette(color_palette)

        _is_valid_color(missing_color)
        self._missing_color = missing_color

    def subscribe(self, cb: CategoryCallback) -> int:
        """Register *cb* for change notifications and return its subscription id.

        Bound methods are held via ``weakref.WeakMethod`` and other callables
        via ``weakref.ref``. A lambda or closure with no other strong
        reference may therefore be collected right away and never fire.
        """
        cb_id = next(self._cb_id_gen)
        try:
            ref = weakref.WeakMethod(cb)  # bound method
        except TypeError:
            ref = weakref.ref(cb)  # function
        self._callbacks[cb_id] = ref
        return cb_id

    def unsubscribe(self, cb_id: int) -> None:
        """Remove the subscription *cb_id*; unknown ids are ignored."""
        self._callbacks.pop(cb_id, None)

    def _notify(self, event: str) -> None:
        """Call every live subscriber with ``(self, event)``; prune dead refs."""
        dead = []
        # Iterate over a snapshot: a callback may subscribe or unsubscribe
        # (e.g. a widget switching to another category), which would otherwise
        # raise "dictionary changed size during iteration".
        for cb_id, ref in list(self._callbacks.items()):
            cb = ref()
            if cb is None:
                dead.append(cb_id)
            else:
                cb(self, event)
        for cb_id in dead:
            self._callbacks.pop(cb_id, None)

    @staticmethod
    def _get_unique_labels_in_values(values):
        """Return the distinct non-null values of a narwhals series as a list."""
        return values.drop_nulls().unique().to_list()

    def _initialize_label_list(self, values, label_list):
        """Return the label list to use, validating a user-provided one.

        A given *label_list* keeps its order and must include every unique
        non-null value; otherwise the unique values are sorted.
        """
        unique_labels = self._get_unique_labels_in_values(values)
        if label_list is not None:
            missing = set(unique_labels).difference(label_list)
            if missing:
                raise RuntimeError(
                    "To initialize the label list we need a label list to include all "
                    f"unique values, these are missing: {missing}"
                )
            # Keep user-provided order as-is (do not sort).
            return list(label_list)
        else:
            return sorted(unique_labels)

    @staticmethod
    def _create_label_coding(label_list):
        """Map each label to its code (1-based position in *label_list*).

        Raises:
            ValueError: if *label_list* contains duplicates (which would make
                codes exceed the number of labels) or more labels than fit in
                a uint16 code.
        """
        label_list = list(label_list)
        label_coding = OrderedDict(
            (label, idx) for idx, label in enumerate(label_list, start=1)
        )
        if len(label_coding) != len(label_list):
            raise ValueError("The label list contains duplicated labels")
        if len(label_coding) > Category._MAX_LABELS:
            raise ValueError(
                f"Too many labels ({len(label_coding)}); at most "
                f"{Category._MAX_LABELS} are supported"
            )
        return label_coding

    def _encode_values(self, values):
        """Store *values* as a uint16 numpy array of codes (0 for unmatched)."""
        coded_values = values.replace_strict(
            self._label_coding, default=0, return_dtype=narwhals.UInt16
        ).to_numpy()
        self._coded_values = coded_values

    @property
    def values(self):
        """Decode the codes back into a native series of the original backend.

        Code 0 becomes ``pandas.NA`` for pandas extension dtypes and ``None``
        otherwise; the result is cast back to the original dtype.
        """
        coded_values = self._coded_values
        label_coding = self._label_coding
        if label_coding is None:
            raise RuntimeError("label coding should be set, but it is not")
        reverse_coding = {code: label for label, code in label_coding.items()}

        if self._values_implementation == narwhals.Implementation.PANDAS:
            if pandas.api.types.is_extension_array_dtype(self._native_values_dtype):
                reverse_coding[0] = pandas.NA
            else:
                reverse_coding[0] = None

            coded_values = pandas.Series(coded_values, name=self.name)
            values = coded_values.replace(reverse_coding).astype(
                self._native_values_dtype
            )
            return values
        else:
            coded_values = narwhals.new_series(
                name=self.name, values=coded_values, backend=self._values_implementation
            )
            reverse_coding[0] = None
            values = coded_values.replace_strict(
                reverse_coding, return_dtype=self._narwhals_values_dtype
            )
            return values.to_native()

    @property
    def name(self) -> str:
        """Name of the original series."""
        return self._name

    @property
    def label_list(self) -> list:
        """Labels in code order: ``label_list[i]`` has code ``i + 1``."""
        label_coding = self._label_coding
        if label_coding is None:
            raise RuntimeError("label coding should be set, but it is not")

        return list(label_coding.keys())

    @staticmethod
    def _get_next_unused_color(used_colors, color_cycle):
        """Take colors from *color_cycle* until one not in *used_colors* appears.

        After one full pass over TAB20 without a free color, the next color is
        returned anyway (repeats are allowed once the palette is exhausted).
        """
        n_colors = len(TAB20_COLORS_RGB)
        n_tried = 0
        while True:
            color = tuple(next(color_cycle))
            n_tried += 1
            if color not in used_colors:
                return color
            if n_tried >= n_colors:
                # TAB20 exhausted: allow repeats
                return color

    def set_label_list(
        self,
        new_labels: list[str] | list[int],
        on_missing_labels=LabelListErrorResponse.ERROR,
        color_palette: dict[Any, tuple[float, float, float]] | None = None,
    ):
        """Replace the label list, recoding values and updating the palette.

        Args:
            new_labels: the new labels, in code order (no duplicates).
            on_missing_labels: a :class:`LabelListErrorResponse` member or its
                value (``"error"`` / ``"missing"``). With ``ERROR`` a label list
                that drops existing labels is rejected; with ``SET_MISSING``
                values of dropped labels become code 0.
            color_palette: optional color overrides for some labels. Other
                labels keep their previous color or get an unused TAB20 color.

        Notifies ``"label_list"`` then ``"palette"`` when something changed.

        Raises:
            ValueError: for an empty/duplicated label list, a list sharing no
                label with the current one, dropped labels under ``ERROR``, an
                unknown *on_missing_labels*, or an invalid override color.
        """
        if not new_labels:
            raise ValueError("No labels given")

        if new_labels == self.label_list:
            return

        # Accept the enum member or its string value; anything else is rejected
        # instead of silently behaving like SET_MISSING.
        on_missing_labels = LabelListErrorResponse(on_missing_labels)

        overrides = color_palette or {}

        old_label_coding = self._label_coding
        if old_label_coding is None:
            raise RuntimeError(
                "label coding should be set before trying to modify the label list"
            )
        old_labels = old_label_coding.keys()

        labels_to_remove = list(set(old_labels).difference(new_labels))
        # Only meaningful when there was something to match: a category whose
        # values are all missing has an empty label list and may gain labels.
        if old_labels and len(labels_to_remove) == len(old_labels):
            raise ValueError(
                "None of the new labels matches the labels found in the category"
            )
        if on_missing_labels == LabelListErrorResponse.ERROR and labels_to_remove:
            raise ValueError(
                f"Some labels are missing from the list ({labels_to_remove}), but the action set for missing is error"
            )

        new_label_coding = self._create_label_coding(new_labels)

        # --- recode values to new codes ---
        # Lookup table old_code -> new_code (0 when the label was dropped), so
        # recoding is a single O(N) gather instead of one pass per label.
        old_values = self._coded_values
        lut_size = max(len(old_label_coding), int(old_values.max(initial=0))) + 1
        lut = numpy.zeros(lut_size, dtype=old_values.dtype)
        for label, old_code in old_label_coding.items():
            new_code = new_label_coding.get(label)
            if new_code is not None:
                lut[old_code] = new_code
        self._coded_values = lut[old_values]
        self._label_coding = new_label_coding

        # --- update palette ---
        old_palette = getattr(self, "_color_palette", {}) or {}
        new_palette: dict[Any, tuple[float, float, float]] = {}

        # pass 1: overrides > old palette
        for label in new_labels:
            if label in overrides:
                color = overrides[label]
                _is_valid_color(color)
                new_palette[label] = tuple(color)
            elif label in old_palette:
                new_palette[label] = tuple(old_palette[label])

        # pass 2: assign remaining labels from TAB20 cycle
        color_cycle = cycle(TAB20_COLORS_RGB)
        used_colors = set(new_palette.values())
        for label in new_labels:
            if label in new_palette:
                continue
            color = self._get_next_unused_color(used_colors, color_cycle)
            used_colors.add(color)
            new_palette[label] = color

        self._color_palette = new_palette

        self._notify("label_list")
        self._notify("palette")

    def set_coded_values(
        self,
        coded_values: numpy.ndarray,
        label_list: list[str] | list[int],
        skip_copying_array=False,
    ):
        """Replace the code array wholesale and notify ``"coded_values"``.

        Args:
            coded_values: new codes with the same shape and dtype as the
                current ones; every code must be in ``0..len(label_list)``.
            label_list: must equal the current label list (guards against
                codes built for a different coding).
            skip_copying_array: store *coded_values* as-is instead of a copy.

        Raises:
            ValueError: on label-list, shape, dtype or code-range mismatch.
            RuntimeError: if the derived coding differs from the current one.
        """
        if not label_list == self.label_list:
            raise ValueError(
                "The label list used to code the new values should match the current one"
            )

        label_encoding = self._create_label_coding(label_list)
        if self._label_coding != label_encoding:
            raise RuntimeError("The new label encoding wouldn't match the old one")

        old_coded_values = self.coded_values
        if old_coded_values.shape != coded_values.shape:
            raise ValueError(
                "The new coded values array has a different size than the older one"
            )
        if old_coded_values.dtype != coded_values.dtype:
            raise ValueError(
                "The dtype of the new coding values does not match the one of the old ones"
            )

        # Codes above K have no label and could not be decoded later.
        max_code = int(coded_values.max(initial=0))
        if max_code > len(label_list):
            raise ValueError(
                f"The new coded values contain code {max_code}, but only "
                f"{len(label_list)} labels are defined"
            )

        if not skip_copying_array:
            coded_values = coded_values.copy(order="K")

        self._coded_values = coded_values
        self._notify("coded_values")

    @property
    def coded_values(self):
        """The internal code array (not a copy; 0 means missing)."""
        return self._coded_values

    @property
    def label_coding(self):
        """List of ``(label, code)`` pairs in code order."""
        label_coding = self._label_coding
        if label_coding is None:
            raise RuntimeError("label coding should be set, but it is not")
        return [(label, code) for label, code in label_coding.items()]

    def create_color_palette(
        self, color_palette: dict[Any, tuple[float, float, float]] | None = None
    ):
        """Build the label -> color palette and notify ``"palette"``.

        With a non-empty *color_palette*, every label must have a valid color
        in it; otherwise colors are taken from TAB20 in label order (cycling
        when there are more than 20 labels).

        Raises:
            KeyError: if *color_palette* lacks a color for some label.
            ValueError: if a provided color is invalid.
        """
        default_colors = cycle(TAB20_COLORS_RGB)

        palette = {}
        for label in self.label_list:
            if color_palette:
                try:
                    color = color_palette[label]
                except KeyError:
                    raise KeyError(
                        f"Color palette given, but color missing for label: {label}"
                    ) from None
                _is_valid_color(color)
            else:
                color = next(default_colors)
            palette[label] = tuple(color)
        self._color_palette = palette
        self._notify("palette")

    @property
    def color_palette(self):
        """A shallow copy of the label -> RGB palette."""
        return self._color_palette.copy()

    @property
    def color_palette_for_codes(self):
        """The palette keyed by code instead of label."""
        palette = self.color_palette

        return {code: palette[label] for label, code in self.label_coding}

    @property
    def missing_color(self):
        """RGB color used for missing values (code 0)."""
        return self._missing_color

    @property
    def num_values(self):
        """Number of encoded values."""
        return self.coded_values.size

    @property
    def num_unassigned(self) -> int:
        """Number of values unassigned / missing."""
        coded = self._coded_values
        return int(numpy.count_nonzero(coded == 0))
372
+
373
+
374
def _esm_source() -> str | Path:
    """Pick the ES module the front end should load.

    When the ``ANY_SCATTER3D_DEV`` environment variable is set to a non-empty
    value, return the dev-server URL (``ANY_SCATTER3D_DEV_URL`` if set,
    otherwise ``DEF_DEV_ESM``). Otherwise return the bundled ``PROD_ESM`` path.
    """
    dev_flag = os.environ.get("ANY_SCATTER3D_DEV", "")
    if not dev_flag:
        return PROD_ESM
    return os.environ.get("ANY_SCATTER3D_DEV_URL", DEF_DEV_ESM)
378
+
379
+
380
def _is_missing(value: object) -> bool:
    """Return True when *value* represents a missing scalar.

    ``None`` and ``pandas.NA`` are recognised directly; anything else is
    delegated to ``pandas.isna``. If that check cannot be turned into a
    single bool (e.g. a multi-element array), the value is treated as
    present and False is returned.
    """
    if value is None:
        return True
    if value is pandas.NA:
        return True
    try:
        # Covers float('nan'), NaT and other pandas missing scalars.
        verdict = pandas.isna(value)
        return bool(verdict)
    except Exception:
        return False
388
+
389
+
390
class Scatter3dWidget(anywidget.AnyWidget):
    """Jupyter widget rendering a 3D scatter plot colored by a :class:`Category`.

    Python-side state (xyz coordinates, category codes, labels, colors) is
    pushed to the front end through the synced ``*_t`` traitlets. The front
    end sends lasso edits (``lasso_request_t`` + ``lasso_mask_t``) and tooltip
    queries (``tooltip_request_t``); they are answered through
    ``lasso_result_t`` and ``tooltip_response_t``. Lasso edits are written back
    into the category, whose change notification re-syncs the traitlets.
    """

    _esm = _esm_source()

    # xyz coords for the points
    # Packed float32 array of shape (N, 3), row-major, with the user's y and z
    # columns swapped (see _pack_xyz_float32_c).
    # TS interprets as Float32Array with length 3*N.
    xyz_bytes_t = traitlets.Bytes(
        default_value=b"",
        help="Packed float32 Nx3, row-major.",
    ).tag(sync=True)

    # Packed uint16 array of length N.
    # Code 0 means "missing / unassigned".
    # Codes 1..K correspond to labels_t[0..K-1].
    coded_values_t = traitlets.Bytes(
        default_value=b"",
        help="Packed uint16 length N. 0=missing, 1..K correspond to labels_t.",
    ).tag(sync=True)

    # List[str] of length K, stable ordering.
    # labels_t[i] corresponds to code (i+1).
    labels_t = traitlets.List(
        traitlets.Unicode(),
        default_value=[],
        help="Label list (length K), where code = index+1.",
    ).tag(sync=True)

    # List[[r,g,b]] of length K, aligned with labels_t.
    # Each component is float in [0,1].
    colors_t = traitlets.List(
        traitlets.List(traitlets.Float(), minlen=3, maxlen=3),
        default_value=[],
        help="Per-label RGB colors (length K) aligned with labels_t; floats in [0,1].",
    ).tag(sync=True)

    # [r,g,b] used when coded value is 0 or otherwise missing
    missing_color_t = traitlets.List(
        traitlets.Float(),
        default_value=list(MISSING_COLOR),
        minlen=3,
        maxlen=3,
        help="RGB color for missing/unassigned (code 0).",
    ).tag(sync=True)

    # --- lasso round-trip channels ---
    # Dict message TS -> Python describing a committed lasso operation.
    lasso_request_t = traitlets.Dict(default_value={}).tag(sync=True)
    # Packed uint8 bitmask (binary).
    lasso_mask_t = traitlets.Bytes(default_value=b"").tag(sync=True)
    # Dict message Python -> TS acknowledging the last request (ok/error).
    lasso_result_t = traitlets.Dict(default_value={}).tag(sync=True)

    point_size_t = traitlets.Float(
        default_value=DEFAULT_POINT_SIZE,
        help="Point size for rendering (three.js PointsMaterial.size).",
    ).tag(sync=True)

    axis_label_size_t = traitlets.Float(
        default_value=DEFAULT_AXIS_LABEL_SIZE,
    ).tag(sync=True)

    show_axes_t = traitlets.Bool(
        default_value=True,
        help=("Whether to draw axis lines (X, Y, Z) from the origin (0,0,0)."),
    ).tag(sync=True)

    tooltip_request_t = traitlets.Dict(default_value={}).tag(sync=True)
    tooltip_response_t = traitlets.Dict(default_value={}).tag(sync=True)

    client_ready_t = traitlets.Bool(
        default_value=False,
        help="Set True by the frontend once JS is initialized and can talk to Python.",
    ).tag(sync=True)

    interactive_ready_t = traitlets.Bool(
        default_value=False,
        help="True once frontend has announced readiness; used to gate UI features.",
    ).tag(sync=True)

    interaction_mode_t = traitlets.Unicode(
        default_value="rotate",
        help="Interaction mode: 'rotate' or 'lasso'.",
    ).tag(sync=True)

    # Active category label (string). In lasso mode it must always be non-empty and valid.
    active_category_t = traitlets.Unicode(
        default_value=None,
        allow_none=True,
        help="Active category label (str) or None. In lasso mode it must be a valid label from labels_t.",
    ).tag(sync=True)

    # Legend placement (split into two validated strings; easier to validate than a Dict schema)
    legend_side_t = traitlets.Unicode(
        default_value="right",
        help="Legend side: 'left' or 'right'.",
    ).tag(sync=True)

    legend_dock_t = traitlets.Unicode(
        default_value="top",
        help="Legend dock: 'top' or 'bottom'.",
    ).tag(sync=True)

    def __init__(
        self,
        xyz: numpy.ndarray,
        category: Category,
        point_ids: Sequence[str]
        | Sequence[int]
        | narwhals.typing.IntoSeriesT
        | None = None,
    ):
        """Create the widget.

        Args:
            xyz: array of shape (N, 3) with the point coordinates.
            category: category with N values used to color the points.
            point_ids: optional per-point ids shown in tooltips (a sequence,
                a 1-D numpy array or a series); defaults to ``1..N``.

        Raises:
            ValueError: on length mismatches between xyz, category and ids.
        """
        super().__init__()
        self._category_cb_id: int | None = None

        if category is not None and xyz.shape[0] != category.num_values:
            raise ValueError(
                f"The number of points ({xyz.shape[0]}) should match "
                f"the number of values in the category: {category.num_values}"
            )

        if point_ids is not None and xyz.shape[0] != len(point_ids):
            raise ValueError(
                f"The number of points ({xyz.shape[0]}) should match "
                f"the number of point ids: {len(point_ids)}"
            )

        # Reference to the bound callback. _set_category subscribes
        # self._on_category_changed directly; Category keeps it via WeakMethod.
        self._category_cb = self._on_category_changed

        self._xyz = None
        self._category = None
        self.xyz = xyz

        self._set_default_sizes()

        self.category = category

        self.point_ids = self._normalize_point_ids(point_ids)

        # clear tooltip state
        self.tooltip_response_t = {}

        # Enforce initial invariants (rotate default allows empty active category)
        self._ensure_active_category_invariants()

    @traitlets.validate("interaction_mode_t")
    def _validate_interaction_mode_t(self, proposal):
        v = proposal["value"]
        if v not in ("rotate", "lasso"):
            raise traitlets.TraitError("interaction_mode_t must be 'rotate' or 'lasso'")
        return v

    @traitlets.validate("legend_side_t")
    def _validate_legend_side_t(self, proposal):
        v = proposal["value"]
        if v not in ("left", "right"):
            raise traitlets.TraitError("legend_side_t must be 'left' or 'right'")
        return v

    @traitlets.validate("legend_dock_t")
    def _validate_legend_dock_t(self, proposal):
        v = proposal["value"]
        if v not in ("top", "bottom"):
            raise traitlets.TraitError("legend_dock_t must be 'top' or 'bottom'")
        return v

    def _ensure_active_category_invariants(self) -> None:
        """
        Enforce invariants for interaction_mode_t and active_category_t.
        - In lasso mode, active_category_t must be a valid label and non-empty.
          If empty/invalid, set deterministically to first label in labels_t.
        - In rotate mode, empty active_category_t is allowed.
          If non-empty but invalid, raise (no silent fallback).
        """
        mode = self.interaction_mode_t
        labels = list(self.labels_t or [])

        if mode == "lasso":
            if not labels:
                # Lasso mode without categories is a hard error.
                raise RuntimeError("Cannot enter lasso mode: labels_t is empty")

            if self.active_category_t is not None and self.active_category_t in labels:
                return

            # Required deterministic behavior: choose the first category label.
            self.active_category_t = labels[0]
            if self.interactive_ready_t:
                self.send_state("active_category_t")
            return

        # rotate mode
        if self.active_category_t is None:
            return
        if self.active_category_t not in labels:
            raise RuntimeError(
                f"active_category_t={self.active_category_t!r} is not present in labels_t"
            )

    @traitlets.observe("interaction_mode_t")
    def _on_interaction_mode_t(self, change) -> None:
        """Validate the lasso invariants when lasso mode is entered.

        If they cannot be satisfied (no labels), the previous mode is restored
        before re-raising, so the widget is not left in lasso mode.
        """
        if change.get("new") != "lasso":
            return
        try:
            self._ensure_active_category_invariants()
        except RuntimeError:
            self.set_trait("interaction_mode_t", change.get("old", "rotate"))
            raise

    @traitlets.observe("active_category_t")
    def _on_active_category_t(self, change) -> None:
        if self.interaction_mode_t == "lasso" and change.get("new") is None:
            # NOTE: traitlets bypasses @validate(...) when allow_none=True.
            # In lasso mode, clearing is a hard error (no silent fallback).
            old = change.get("old")
            if old is not None:
                # restore previous valid value to keep state consistent
                self.set_trait("active_category_t", old)
            raise traitlets.TraitError("active_category_t cannot be None in lasso mode")

    @traitlets.observe("labels_t")
    def _on_labels_t(self, change) -> None:
        """Re-check the active category whenever the label list changes."""
        if self.interaction_mode_t == "lasso":
            self._ensure_active_category_invariants()
        elif self.active_category_t is not None and self.active_category_t not in (
            change.get("new") or []
        ):
            # in rotate mode, invalid active is a real error
            raise RuntimeError(
                f"active_category_t={self.active_category_t!r} is not present in labels_t"
            )

    @traitlets.observe("client_ready_t")
    def _on_client_ready_t(self, change) -> None:
        # Only ever transition False -> True
        if bool(change.get("new")) is True and self.interactive_ready_t is not True:
            self.interactive_ready_t = True
            self.send_state("interactive_ready_t")

    def _set_default_sizes(self):
        """Derive point and axis-label sizes from the largest finite |coordinate|.

        Falls back to DEFAULT_POINT_SIZE / DEFAULT_AXIS_LABEL_SIZE when there
        are no points, no finite coordinates, or every point is at the origin.
        In those cases scaling would give 0 or NaN, which the point_size
        setter itself rejects. This must never crash widget construction.
        """
        max_abs = 0.0
        if self.num_points:
            # The stored array has y/z swapped; that does not change the
            # largest absolute coordinate, so no un-swapped copy is needed.
            abs_coords = numpy.abs(self._xyz)
            finite = abs_coords[numpy.isfinite(abs_coords)]
            if finite.size:
                max_abs = float(finite.max())

        if max_abs <= 0.0:
            self.point_size_t = float(DEFAULT_POINT_SIZE)
            self.axis_label_size_t = float(DEFAULT_AXIS_LABEL_SIZE)
            return

        self.point_size_t = max_abs / 20.0
        self.axis_label_size_t = max_abs / 5.0

    def _normalize_point_ids(self, point_ids):
        """Return *point_ids* as a tuple of length N (defaults to ``1..N``).

        Accepts None, a narwhals Series, a 1-D numpy array, a non-string
        sequence, or a native series that narwhals can wrap.

        Raises:
            TypeError: for unsupported inputs.
            ValueError: if the length does not match the number of points.
        """
        type_error_msg = "point_ids must be a Series, a sequence of values, or None"
        num_points = self.num_points

        if point_ids is None:
            point_ids = tuple(range(1, num_points + 1))
        elif isinstance(point_ids, narwhals.Series):
            point_ids = tuple(point_ids.to_list())
        elif isinstance(point_ids, numpy.ndarray):
            if point_ids.ndim != 1:
                raise TypeError(type_error_msg)
            point_ids = tuple(point_ids.tolist())
        elif isinstance(point_ids, (str, bytes)):
            raise TypeError(type_error_msg)
        elif isinstance(point_ids, Sequence):
            point_ids = tuple(point_ids)
        else:
            # Native dataframe-library series (e.g. pandas) are not Sequences;
            # wrap them with narwhals, as Category does for its values.
            try:
                series = narwhals.from_native(point_ids, series_only=True)
            except TypeError:
                raise TypeError(type_error_msg) from None
            point_ids = tuple(series.to_list())

        if len(point_ids) != num_points:
            raise ValueError("point_ids length must match number of points")

        return point_ids

    def _category_label_for_index(self, idx: int) -> str | None:
        """Return the label of point *idx* as a string, or None when missing."""
        if self._category is None:
            return None
        coded = self._category.coded_values
        code = int(coded[idx])
        if code <= 0:
            return None
        # label_list is 0-based, codes are 1..K
        return str(self._category.label_list[code - 1])

    @traitlets.validate("active_category_t")
    def _validate_active_category_t(self, proposal):
        v = proposal["value"]

        if v is None:
            # allowed only in rotate mode
            mode = getattr(self, "interaction_mode_t", "rotate")
            if mode == "lasso":
                raise traitlets.TraitError(
                    "active_category_t cannot be None in lasso mode"
                )
            return None

        if not isinstance(v, str):
            raise traitlets.TraitError("active_category_t must be a string or None")

        if v == "":
            # reject empty string entirely; it was legacy sentinel
            raise traitlets.TraitError(
                "active_category_t must be None (rotate) or a valid label string (rotate/lasso); empty string is not allowed"
            )

        labels = list(self.labels_t or [])
        mode = getattr(self, "interaction_mode_t", "rotate")

        if mode == "lasso":
            if not labels:
                raise traitlets.TraitError(
                    "Cannot set active_category_t: labels_t is empty"
                )
            if v not in labels:
                raise traitlets.TraitError(
                    f"active_category_t={v!r} is not present in labels_t"
                )
            return v

        # rotate mode
        if labels and v not in labels:
            raise traitlets.TraitError(
                f"active_category_t={v!r} is not present in labels_t"
            )
        return v

    @traitlets.observe("tooltip_request_t")
    def _on_tooltip_request(self, change) -> None:
        """Answer a ``{"kind": "tooltip", "request_id", "i"}`` request.

        Replies via tooltip_response_t with ``status`` "ok" (plus ``data``:
        index, id and category label) or "error" (plus ``message``).
        """
        req = change["new"] or {}
        if not isinstance(req, dict):
            return
        if req.get("kind") != "tooltip":
            return

        try:
            request_id = int(req.get("request_id", 0) or 0)
        except (TypeError, ValueError) as e:
            # Reply instead of raising inside the observer, so the front end
            # gets an answer rather than waiting forever.
            self.tooltip_response_t = {
                "request_id": req.get("request_id"),
                "status": "error",
                "message": f"invalid request_id: {e}",
            }
            self.send_state("tooltip_response_t")
            return

        i = req.get("i", None)

        try:
            i = int(i)
            if i < 0 or i >= self.num_points:
                raise IndexError(f"point index out of range: {i}")

            data = {
                "idx": i,
                "id": str(self.point_ids[i]),
                "category": self._category_label_for_index(i),
            }

            self.tooltip_response_t = {
                "request_id": request_id,
                "status": "ok",
                "data": data,
            }
        except Exception as e:
            self.tooltip_response_t = {
                "request_id": request_id,
                "status": "error",
                "message": str(e),
            }

        # Force comm sync (important for anywidget)
        self.send_state("tooltip_response_t")

    def _on_category_changed(self, category: Category, event: str) -> None:
        """
        Called when Category mutates.
        """
        # Sanity: ignore stale callbacks (if category replaced)
        if category is not self._category:
            return
        self._sync_traitlets_from_category()

    @staticmethod
    def _pack_xyz_float32_c(xyz: numpy.ndarray) -> tuple[numpy.ndarray, bytes]:
        """
        Return (xyz_float32_c, packed_bytes).
        - xyz_float32_c: float32, C-contiguous, shape (N,3), with the y and z
          columns swapped; always a new array, the input is never modified.
        - packed_bytes: xyz_float32_c.tobytes(order="C")
        """
        if not isinstance(xyz, numpy.ndarray):
            raise ValueError("xyz should be a numpy array")

        if xyz.ndim != 2 or xyz.shape[1] != 3:
            raise ValueError("xyz should have shape (N, 3)")

        # Convert dtype to float32 (TS expects Float32Array).
        xyz_f32 = numpy.asarray(xyz, dtype=numpy.float32)

        # Remap (x,y,z) -> (x,z,y) so that "z" in user data becomes "up" (Y) in Three.js.
        # asarray returns the caller's own array when it is already float32,
        # so the swap must not be done in place. Fancy indexing always
        # allocates; ascontiguousarray guarantees row-major layout.
        xyz_f32 = numpy.ascontiguousarray(xyz_f32[:, [0, 2, 1]])

        return xyz_f32, xyz_f32.tobytes(order="C")

    def _get_xyz(self) -> numpy.ndarray:
        """Return a copy of the coordinates in the user's (x, y, z) order."""
        if self._xyz is None:
            raise RuntimeError("xyz has not been set")
        # Undo the y/z swap; fancy indexing returns a new array.
        return numpy.ascontiguousarray(self._xyz[:, [0, 2, 1]])

    def _set_xyz(self, xyz: numpy.ndarray) -> None:
        """Validate, pack and sync new coordinates (shape (N, 3))."""
        xyz_f32, xyz_bytes = self._pack_xyz_float32_c(xyz)

        # If category already set, enforce N consistency
        if self._category is not None and xyz_f32.shape[0] != self.category.num_values:
            raise ValueError(
                f"The number of points ({xyz_f32.shape[0]}) should match "
                f"the number of values in the category: {self.category.num_values}"
            )

        self._xyz = xyz_f32
        self.xyz_bytes_t = xyz_bytes

    xyz = property(_get_xyz, _set_xyz)

    @staticmethod
    def _pack_u16_c(arr: numpy.ndarray) -> bytes:
        """Serialize *arr* as C-ordered uint16 bytes."""
        arr_u16 = numpy.asarray(arr, dtype=numpy.uint16, order="C")
        if not arr_u16.flags["C_CONTIGUOUS"]:
            arr_u16 = numpy.ascontiguousarray(arr_u16)
        return arr_u16.tobytes(order="C")

    def _sync_traitlets_from_category(self) -> None:
        """
        Push the Category state into synced transport traitlets.
        Everything is validated before any trait is assigned, so a length
        mismatch does not leave the front end with half-updated state.
        """
        if self._category is None:
            raise RuntimeError("The category should be set")

        cat = self._category

        # coded values: uint16 bytes, length N
        coded = cat.coded_values
        if coded.shape[0] != self.num_points:
            raise RuntimeError(
                f"Category has {coded.shape[0]} values but xyz has {self.num_points} points"
            )

        # labels_t must be JSON-friendly; enforce str
        labels = [str(lbl) for lbl in cat.label_list]

        # colors aligned with labels order
        # Category stores palette keyed by original labels; we reconstruct in label_list order.
        palette = cat.color_palette  # label -> (r,g,b)
        colors = [list(map(float, palette[lbl])) for lbl in cat.label_list]

        if len(colors) != len(labels):
            raise RuntimeError(
                "Internal error: colors_t length must match labels_t length"
            )

        self.labels_t = labels
        self.coded_values_t = self._pack_u16_c(coded)
        self.colors_t = colors
        # missing color
        self.missing_color_t = list(map(float, cat.missing_color))

        self._ensure_active_category_invariants()

    def _get_category(self):
        return self._category

    def _set_category(self, category: Category) -> None:
        """Swap the displayed category, moving the change subscription to it."""
        if self._xyz is not None and category.num_values != self.num_points:
            raise ValueError(
                f"The number of values in the category ({category.num_values}) "
                f"should match the number of points {self.num_points}"
            )
        if self._category is not None and self._category_cb_id is not None:
            self._category.unsubscribe(self._category_cb_id)

        self._category = category
        # Subscribe to new category
        self._category_cb_id = category.subscribe(self._on_category_changed)
        self._sync_traitlets_from_category()

    category = property(_get_category, _set_category)

    @property
    def num_points(self):
        """Number of points N (read without copying the coordinate array)."""
        if self._xyz is None:
            raise RuntimeError("xyz has not been set")
        return self._xyz.shape[0]

    def close(self):
        # detach callback to avoid keeping references around.
        if self._category is not None and self._category_cb_id is not None:
            self._category.unsubscribe(self._category_cb_id)
            self._category_cb_id = None
        super().close()

    def _label_to_code_map(self) -> dict[str, int]:
        # labels_t[i] -> code i+1
        return {lbl: i + 1 for i, lbl in enumerate(self.labels_t)}

    def _unpack_mask(self, mask_payload) -> numpy.ndarray:
        """Decode a big-endian-bit-ordered packed bitmask into a bool array of length N."""
        n = self.num_points
        needed = (n + 7) // 8

        if isinstance(mask_payload, (bytes, bytearray, memoryview)):
            mask_bytes = bytes(mask_payload)
        else:
            raise ValueError(
                f"lasso_mask_t must be bytes-like, got {type(mask_payload)}"
            )

        if len(mask_bytes) < needed:
            raise ValueError(
                f"lasso_mask_t too short: got {len(mask_bytes)} bytes, need {needed} for N={n}"
            )

        b = numpy.frombuffer(mask_bytes, dtype=numpy.uint8, count=needed)
        bits = numpy.unpackbits(b, bitorder="big")
        return bits[:n].astype(bool, copy=False)

    def _apply_lasso_mask_edit(self, op: str, code: int, mask: numpy.ndarray) -> int:
        """
        Apply add/remove using a boolean mask of length N.
        "add" assigns *code* to every selected point; "remove" unassigns
        (sets to 0) the selected points that currently have *code*.
        Returns number of points actually changed.
        """
        if self._category is None:
            raise RuntimeError("No category set")

        if mask.dtype != numpy.bool_ or mask.shape != (self.num_points,):
            raise ValueError("Internal error: mask must be bool with shape (N,)")

        if code < 0 or code > 65535:
            raise ValueError(f"Invalid code {code} (must fit uint16)")
        if code == 0 and op == "add":
            raise ValueError("Cannot add code 0 (reserved for missing/unassigned)")
        # A code without a label would corrupt the category (codes are 1..K).
        num_labels = len(self._category.label_list)
        if code > num_labels:
            raise ValueError(
                f"Invalid code {code}: the category only has {num_labels} labels"
            )

        old = self._category.coded_values
        new = old.copy()
        target = numpy.uint16(code)

        if op == "add":
            changed = int(numpy.sum(new[mask] != target))
            new[mask] = target
        elif op == "remove":
            # Only remove points currently in that label. Removing code 0 is a
            # no-op (those points are already unassigned) and changes nothing.
            if code:
                to_zero = mask & (new == target)
            else:
                to_zero = numpy.zeros_like(mask)
            changed = int(numpy.sum(to_zero))
            new[to_zero] = numpy.uint16(0)
        else:
            raise ValueError(f"Unknown op: {op!r}")

        # Update Category (will notify; widget callback syncs coded_values_t etc.)
        self._category.set_coded_values(
            coded_values=new,
            label_list=self._category.label_list,
            skip_copying_array=True,
        )
        return changed

    @traitlets.observe("lasso_request_t")
    def _on_lasso_request_t(self, change) -> None:
        """Apply a ``lasso_commit`` request and report the outcome in lasso_result_t.

        The target is given by ``code`` or by ``label`` (looked up in labels_t);
        the selection comes from lasso_mask_t. Errors are reported with
        ``status: "error"`` rather than raised.
        """
        req = change.get("new", {})
        if not req:
            return

        request_id = req.get("request_id")
        res: dict[str, object] = {"request_id": request_id}

        try:
            if req.get("kind") != "lasso_commit":
                raise ValueError(f"Unsupported kind: {req.get('kind')!r}")

            op = req.get("op")
            if op not in ("add", "remove"):
                raise ValueError(f"Invalid op: {op!r}")

            # resolve code from either explicit code or label
            if "code" in req and req["code"] is not None:
                code = int(req["code"])
            else:
                label = req.get("label")
                if label is None:
                    raise ValueError("Missing field: label (or code)")
                label_s = str(label)
                m = self._label_to_code_map()
                if label_s not in m:
                    raise ValueError(f"Unknown label: {label_s!r}")
                code = m[label_s]

            # unpack mask from bytes traitlet
            if not isinstance(self.lasso_mask_t, (bytes, bytearray, memoryview)):
                raise RuntimeError(
                    f"Internal error: lasso_mask_t must be bytes, got {type(self.lasso_mask_t)}"
                )
            mask = self._unpack_mask(self.lasso_mask_t)
            num_selected = int(numpy.sum(mask))

            changed = self._apply_lasso_mask_edit(op=op, code=code, mask=mask)

            res.update(
                {
                    "status": "ok",
                    "num_selected": num_selected,
                    "num_changed": changed,
                }
            )
        except Exception as e:
            res.update({"status": "error", "message": str(e)})

        self.lasso_result_t = res
        # Force comm sync even when the result equals the previous one
        # (assigning an equal dict does not trigger a traitlets change).
        self.send_state("lasso_result_t")

    def _get_point_size(self) -> float:
        return float(self.point_size_t)

    def _set_point_size(self, value: float) -> None:
        v = float(value)
        if not numpy.isfinite(v) or v <= 0:
            raise ValueError("point_size must be a finite positive number")
        self.point_size_t = v

    point_size = property(_get_point_size, _set_point_size)

    def _get_axis_label_size(self) -> float:
        return float(self.axis_label_size_t)

    def _set_axis_label_size(self, value: float) -> None:
        v = float(value)
        if not numpy.isfinite(v) or v <= 0:
            raise ValueError("axes label size must be a finite positive number")
        self.axis_label_size_t = v

    axis_label_size = property(_get_axis_label_size, _set_axis_label_size)