scatter3d-anywidget 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scatter3d/__init__.py +3 -0
- scatter3d/scatter3d.py +1011 -0
- scatter3d/static/scatter3d.js +22647 -0
- scatter3d/static/scatter3d.js.map +1 -0
- scatter3d/widget_test.py +48 -0
- scatter3d_anywidget-0.1.2.dist-info/METADATA +145 -0
- scatter3d_anywidget-0.1.2.dist-info/RECORD +8 -0
- scatter3d_anywidget-0.1.2.dist-info/WHEEL +4 -0
scatter3d/scatter3d.py
ADDED
|
@@ -0,0 +1,1011 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from itertools import cycle, count
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
from typing import Any, Callable, Sequence
|
|
7
|
+
import weakref
|
|
8
|
+
import base64
|
|
9
|
+
|
|
10
|
+
import anywidget
|
|
11
|
+
import traitlets
|
|
12
|
+
import numpy
|
|
13
|
+
import pandas
|
|
14
|
+
import narwhals
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# --- Frontend bundle locations ------------------------------------------------
PACKAGE_DIR = Path(__file__).parent
JAVASCRIPT_DIR = PACKAGE_DIR.joinpath("static")
PROD_ESM = JAVASCRIPT_DIR.joinpath("scatter3d.js")
# Development server entry point returned by _esm_source() when the
# ANY_SCATTER3D_DEV environment variable is set (overridable via
# ANY_SCATTER3D_DEV_URL).
DEF_DEV_ESM = "http://127.0.0.1:5173/src/index.ts"

# --- Binary transport dtypes ---------------------------------------------------
FLOAT_TYPE = "<f4"
FLOAT_TYPE_TS = "float32"
# NOTE(review): coded values are actually produced/packed as uint16 elsewhere
# in this module; confirm whether this uint32 constant is still used.
CATEGORY_CODES_DTYPE = "<u4"  # uint32 little-endian

# --- Missing-value presentation ------------------------------------------------
MISSING_COLOR = (0.6, 0.6, 0.6)
MISSING_CATEGORY_VALUE = "Unassigned"

# --- Theme colours and default sizes -------------------------------------------
DARK_GREY = "#111111"
WHITE = "#ffffff"
DEFAULT_POINT_SIZE = 0.15
DEFAULT_AXIS_LABEL_SIZE = 0.2

# The 20-colour "tab20" qualitative palette as RGB floats in [0, 1]
# (each component is an 8-bit channel divided by 255; hex shown per row).
TAB20_COLORS_RGB = [
    (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),  # #1f77b4
    (0.6823529411764706, 0.7803921568627451, 0.9098039215686274),  # #aec7e8
    (1.0, 0.4980392156862745, 0.054901960784313725),  # #ff7f0e
    (1.0, 0.7333333333333333, 0.47058823529411764),  # #ffbb78
    (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),  # #2ca02c
    (0.596078431372549, 0.8745098039215686, 0.5411764705882353),  # #98df8a
    (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),  # #d62728
    (1.0, 0.596078431372549, 0.5882352941176471),  # #ff9896
    (0.5803921568627451, 0.403921568627451, 0.7411764705882353),  # #9467bd
    (0.7725490196078432, 0.6901960784313725, 0.8352941176470589),  # #c5b0d5
    (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),  # #8c564b
    (0.7686274509803922, 0.611764705882353, 0.5803921568627451),  # #c49c94
    (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),  # #e377c2
    (0.9686274509803922, 0.7137254901960784, 0.8235294117647058),  # #f7b6d2
    (0.4980392156862745, 0.4980392156862745, 0.4980392156862745),  # #7f7f7f
    (0.7803921568627451, 0.7803921568627451, 0.7803921568627451),  # #c7c7c7
    (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),  # #bcbd22
    (0.8588235294117647, 0.8588235294117647, 0.5529411764705883),  # #dbdb8d
    (0.09019607843137255, 0.7450980392156863, 0.8117647058823529),  # #17becf
    (0.6196078431372549, 0.8549019607843137, 0.8980392156862745),  # #9edae5
]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class LabelListErrorResponse(Enum):
    """Policy for ``Category.set_label_list`` when current labels are dropped.

    ERROR: refuse the change and raise ``ValueError``.
    SET_MISSING: accept it; values whose label disappears get code 0
    (missing/unassigned).
    """

    ERROR = "error"  # raise on dropped labels
    SET_MISSING = "missing"  # recode dropped labels as missing
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _is_valid_color(color):
    """Validate an RGB colour given as a tuple of three numbers in [0, 1].

    Despite its name, this function does not return a boolean. It returns
    ``None`` when *color* is valid and raises otherwise.

    Raises:
        ValueError: if *color* is not a 3-tuple, or if any component is not
            an orderable number, is NaN, or lies outside [0, 1].
    """
    if not isinstance(color, tuple):
        raise ValueError(f"Invalid color, should be tuples with three floats {color}")
    if len(color) != 3:
        raise ValueError(f"Invalid color, should be tuples with three floats {color}")
    for value in color:
        try:
            # A chained comparison is False for NaN, so NaN is rejected too.
            # The old `value < 0 or value > 1` test let NaN through.
            in_range = 0 <= value <= 1
        except TypeError:
            # Non-numeric components (e.g. strings) used to leak a TypeError.
            raise ValueError(
                f"Invalid color, should be coded as floats from 0 to 1 {color}"
            ) from None
        if not in_range:
            raise ValueError(
                f"Invalid color, should be coded as floats from 0 to 1 {color}"
            )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Subscriber signature for Category change notifications.
# A subscriber is invoked as ``cb(category, event)``, where *event* is one of
# "label_list", "palette" or "coded_values". Its return value is ignored.
CategoryCallback = Callable[
    ["Category", str],  # (the notifying Category, event name)
    None,
]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class Category:
    """Categorical values stored as integer codes plus a per-label colour palette.

    Each value is encoded as an integer code: ``0`` means missing/unassigned
    and ``1..K`` correspond to ``label_list[0..K-1]``. Codes are produced with
    the narwhals ``UInt16`` dtype (see ``_encode_values``), which caps K at
    65535.

    Observers registered through :meth:`subscribe` are called as
    ``cb(category, event)``, where ``event`` is ``"label_list"``,
    ``"palette"`` or ``"coded_values"``. Callbacks are held only by weak
    reference.
    """

    def __init__(
        self,
        values: narwhals.typing.IntoSeriesT,
        label_list=None,
        color_palette: dict[Any, tuple[float, float, float]] | None = None,
        missing_color: tuple[float, float, float] = MISSING_COLOR,
    ):
        """Encode *values* and build the colour palette.

        Args:
            values: a native series (anything ``narwhals.from_native`` accepts
                with ``series_only=True``). Values not found in the label
                coding are encoded as code 0.
            label_list: optional explicit label order. It must contain every
                non-null value; extra labels are allowed. Defaults to the
                sorted unique non-null values.
            color_palette: optional ``label -> (r, g, b)`` mapping. When given,
                it must provide a colour for every label.
            missing_color: RGB tuple used for missing values.

        Raises:
            RuntimeError: if *label_list* omits a value present in *values*.
            ValueError: for duplicated labels, more than 65535 labels, or an
                invalid colour.
            KeyError: if *color_palette* lacks a colour for some label.
        """
        self._cb_id_gen = count(1)
        self._callbacks: dict[int, weakref.ReferenceType] = {}

        # Remember the caller's native dtype and backend so that `values`
        # can rebuild a series of the same kind from the codes.
        self._native_values_dtype = values.dtype
        values = narwhals.from_native(values, series_only=True)
        self._narwhals_values_dtype = values.dtype
        self._name = values.name
        self._values_implementation = values.implementation

        label_list = self._initialize_label_list(values, label_list)
        self._label_coding = self._create_label_coding(label_list)

        self._encode_values(values)

        self.create_color_palette(color_palette)

        _is_valid_color(missing_color)
        self._missing_color = missing_color

    def subscribe(self, cb: CategoryCallback) -> int:
        """Register *cb* for change notifications and return its id.

        Bound methods are stored via ``weakref.WeakMethod`` and other
        callables via ``weakref.ref``. The caller must therefore keep *cb*
        alive: a lambda passed directly is typically collected right away
        and never called.
        """
        cb_id = next(self._cb_id_gen)
        try:
            ref = weakref.WeakMethod(cb)  # bound method
        except TypeError:
            ref = weakref.ref(cb)  # function
        self._callbacks[cb_id] = ref
        return cb_id

    def unsubscribe(self, cb_id: int) -> None:
        """Forget the callback registered under *cb_id* (unknown ids are ignored)."""
        self._callbacks.pop(cb_id, None)

    def _notify(self, event: str) -> None:
        """Call every live subscriber as ``cb(self, event)`` and prune dead refs."""
        dead = []
        # Iterate over a snapshot: a callback may subscribe or unsubscribe
        # while we notify, which used to raise "dictionary changed size
        # during iteration".
        for cb_id, ref in list(self._callbacks.items()):
            if cb_id not in self._callbacks:
                # Unsubscribed by an earlier callback during this round.
                continue
            cb = ref()
            if cb is None:
                dead.append(cb_id)
            else:
                cb(self, event)
        for cb_id in dead:
            self._callbacks.pop(cb_id, None)

    @staticmethod
    def _get_unique_labels_in_values(values):
        """Return the distinct non-null values of a narwhals series as a list."""
        return values.drop_nulls().unique().to_list()

    def _initialize_label_list(self, values, label_list):
        """Return the initial label list: the user's (checked) or sorted uniques."""
        unique_labels = self._get_unique_labels_in_values(values)
        if label_list is not None:
            missing = set(unique_labels).difference(label_list)
            if missing:
                raise RuntimeError(
                    "To initialize the label list we need a label list to include all "
                    f"unique values, these are missing: {missing}"
                )
            # Keep user-provided order as-is (do not sort).
            return list(label_list)
        else:
            return sorted(unique_labels)

    @staticmethod
    def _create_label_coding(label_list):
        """Map each label to its 1-based code, preserving *label_list* order.

        Raises:
            ValueError: if *label_list* contains duplicates, or has more
                labels than fit in a uint16 code (code 0 is reserved).
        """
        labels = list(label_list)
        label_coding = OrderedDict(
            (label, idx) for idx, label in enumerate(labels, start=1)
        )
        # Duplicates collapse in the mapping, leaving codes that no longer
        # match positions in label_list (code - 1 indexing would break).
        if len(label_coding) != len(labels):
            raise ValueError(f"The label list contains duplicated labels: {labels}")
        max_labels = int(numpy.iinfo(numpy.uint16).max)
        if len(labels) > max_labels:
            raise ValueError(
                f"Too many labels ({len(labels)}); at most {max_labels} are supported"
            )
        return label_coding

    def _encode_values(self, values):
        """Store *values* as a numpy array of UInt16 codes (unmapped -> 0)."""
        coded_values = values.replace_strict(
            self._label_coding, default=0, return_dtype=narwhals.UInt16
        ).to_numpy()
        self._coded_values = coded_values

    @property
    def values(self):
        """Decode the codes back to a native series of the original type and dtype."""
        coded_values = self._coded_values
        label_coding = self._label_coding
        if label_coding is None:
            raise RuntimeError("label coding should be set, but it is not")
        reverse_coding = {code: label for label, code in label_coding.items()}

        if self._values_implementation == narwhals.Implementation.PANDAS:
            # Extension dtypes represent missing values with pandas.NA,
            # classic numpy-backed dtypes with None.
            if pandas.api.types.is_extension_array_dtype(self._native_values_dtype):
                reverse_coding[0] = pandas.NA
            else:
                reverse_coding[0] = None

            coded_values = pandas.Series(coded_values, name=self.name)
            values = coded_values.replace(reverse_coding).astype(
                self._native_values_dtype
            )
            return values
        else:
            coded_values = narwhals.new_series(
                name=self.name, values=coded_values, backend=self._values_implementation
            )
            reverse_coding[0] = None
            values = coded_values.replace_strict(
                reverse_coding, return_dtype=self._narwhals_values_dtype
            )
            return values.to_native()

    @property
    def name(self) -> str:
        """Name of the original series."""
        return self._name

    @property
    def label_list(self) -> list:
        """Labels in code order: ``label_list[i]`` has code ``i + 1``."""
        label_coding = self._label_coding
        if label_coding is None:
            raise RuntimeError("label coding should be set, but it is not")

        return list(label_coding.keys())

    @staticmethod
    def _get_next_unused_color(used_colors, color_cycle):
        """Return the next TAB20 colour not in *used_colors*.

        Once a full TAB20 cycle has been tried, repeats are allowed.
        """
        n_colors = len(TAB20_COLORS_RGB)
        n_tried = 0
        while True:
            color = tuple(next(color_cycle))
            n_tried += 1
            if color not in used_colors:
                return color
            if n_tried >= n_colors:
                # TAB20 exhausted: allow repeats
                return color

    def set_label_list(
        self,
        new_labels: list[str] | list[int],
        on_missing_labels=LabelListErrorResponse.ERROR,
        color_palette: dict[Any, tuple[float, float, float]] | None = None,
    ):
        """Replace the label list, recoding values and updating the palette.

        Values whose label survives keep that label under its new code.
        Values whose label is dropped become missing (code 0), which is only
        allowed when *on_missing_labels* is ``LabelListErrorResponse.SET_MISSING``.

        Colour precedence per label: an entry in *color_palette*, then the
        label's previous palette colour, then the next unused TAB20 colour.

        Raises:
            ValueError: if *new_labels* is empty or has duplicates; if it
                shares no label with a non-empty current list; or if it drops
                labels while *on_missing_labels* is ``ERROR``.
            RuntimeError: if the label coding has not been initialised.

        Notifies ``"label_list"`` and then ``"palette"`` when a change is applied.
        """
        if not new_labels:
            raise ValueError("No labels given")
        new_labels = list(new_labels)

        overrides = color_palette or {}

        # Nothing to do only when the labels are unchanged AND no colour
        # overrides were given (overrides used to be ignored in this case).
        if new_labels == self.label_list and not overrides:
            return

        old_label_coding = self._label_coding
        if old_label_coding is None:
            raise RuntimeError(
                "label coding should be set before trying to modify the label list"
            )
        old_labels = list(old_label_coding.keys())

        labels_to_remove = list(set(old_labels).difference(new_labels))
        # An empty current label list (e.g. all-missing values) can always be
        # extended; previously 0 == 0 made this raise unconditionally.
        if old_labels and len(labels_to_remove) == len(old_labels):
            raise ValueError(
                "None of the new labels matches the labels found in the category"
            )
        if on_missing_labels == LabelListErrorResponse.ERROR and labels_to_remove:
            raise ValueError(
                f"Some labels are missing from the list ({labels_to_remove}), but the action set for missing is error"
            )

        new_label_coding = self._create_label_coding(new_labels)

        # --- recode values: old code -> new code via one lookup table ---
        # Codes whose label is dropped (and any unexpected code) map to 0.
        old_values = self._coded_values
        max_code = max(old_label_coding.values(), default=0)
        if old_values.size:
            max_code = max(max_code, int(old_values.max()))
        lookup = numpy.zeros(max_code + 1, dtype=old_values.dtype)
        for label, old_code in old_label_coding.items():
            new_code = new_label_coding.get(label)
            if new_code is not None:
                lookup[old_code] = new_code
        self._coded_values = lookup[old_values]
        self._label_coding = new_label_coding

        # --- update palette ---
        old_palette = getattr(self, "_color_palette", {}) or {}
        new_palette: dict[Any, tuple[float, float, float]] = {}

        # pass 1: overrides > old palette
        for label in new_labels:
            if label in overrides:
                color = overrides[label]
                _is_valid_color(color)
                new_palette[label] = tuple(color)
            elif label in old_palette:
                new_palette[label] = tuple(old_palette[label])

        # pass 2: assign remaining labels from TAB20 cycle
        color_cycle = cycle(TAB20_COLORS_RGB)
        used_colors = set(new_palette.values())
        for label in new_labels:
            if label in new_palette:
                continue
            color = self._get_next_unused_color(used_colors, color_cycle)
            used_colors.add(color)
            new_palette[label] = color

        self._color_palette = new_palette

        self._notify("label_list")
        self._notify("palette")

    def set_coded_values(
        self,
        coded_values: numpy.ndarray,
        label_list: list[str] | list[int],
        skip_copying_array=False,
    ):
        """Replace the codes wholesale, for values coded against the current labels.

        Args:
            coded_values: array with the same shape and dtype as
                :attr:`coded_values`; entries must be in ``0..K``.
            label_list: the label list the codes refer to; it must equal
                :attr:`label_list`.
            skip_copying_array: store *coded_values* without copying it.

        Raises:
            ValueError: on a label list, shape or dtype mismatch, or on codes
                greater than the number of labels.
            RuntimeError: if the label coding would differ from the current one.

        Notifies ``"coded_values"``.
        """
        if list(label_list) != self.label_list:
            raise ValueError(
                "The label list used to code the new values should match the current one"
            )

        label_encoding = self._create_label_coding(label_list)
        if self._label_coding != label_encoding:
            raise RuntimeError("The new label encoding wouldn't match the old one")

        old_coded_values = self.coded_values
        if old_coded_values.shape != coded_values.shape:
            raise ValueError(
                "The new coded values array has a different size than the older one"
            )
        if old_coded_values.dtype != coded_values.dtype:
            raise ValueError(
                "The dtype of the new coding values does not match the one of the old ones"
            )

        # Codes above K would point past the end of label_list when decoded.
        n_labels = len(label_encoding)
        if coded_values.size and int(coded_values.max()) > n_labels:
            raise ValueError(
                f"The new coded values contain codes greater than the number of labels ({n_labels})"
            )

        if not skip_copying_array:
            coded_values = coded_values.copy(order="K")

        self._coded_values = coded_values
        self._notify("coded_values")

    @property
    def coded_values(self):
        """The internal code array (0 = missing, 1..K = labels); not a copy."""
        return self._coded_values

    @property
    def label_coding(self):
        """List of ``(label, code)`` pairs in code order."""
        label_coding = self._label_coding
        if label_coding is None:
            raise RuntimeError("label coding should be set, but it is not")
        return [(label, code) for label, code in label_coding.items()]

    def create_color_palette(
        self, color_palette: dict[Any, tuple[float, float, float]] | None = None
    ):
        """Rebuild the palette from *color_palette* or from the TAB20 cycle.

        When *color_palette* is non-empty, it must contain a valid colour for
        every label; otherwise labels take TAB20 colours in label order.
        Notifies ``"palette"``.

        Raises:
            KeyError: if *color_palette* lacks a colour for some label.
            ValueError: if a provided colour is invalid.
        """
        default_colors = cycle(TAB20_COLORS_RGB)

        palette = {}
        for label in self.label_list:
            if color_palette:
                try:
                    color = color_palette[label]
                    _is_valid_color(color)
                except KeyError:
                    raise KeyError(
                        f"Color palette given, but color missing for label: {label}"
                    ) from None
            else:
                color = next(default_colors)
            palette[label] = tuple(color)
        self._color_palette = palette
        self._notify("palette")

    @property
    def color_palette(self):
        """Shallow copy of the ``label -> (r, g, b)`` palette."""
        return self._color_palette.copy()

    @property
    def color_palette_for_codes(self):
        """Palette keyed by code instead of label (code 0 is not included)."""
        palette = self.color_palette

        return {code: palette[label] for label, code in self.label_coding}

    @property
    def missing_color(self):
        """RGB tuple used for missing/unassigned values."""
        return self._missing_color

    @property
    def num_values(self):
        """Total number of values (including missing ones)."""
        return self.coded_values.size

    @property
    def num_unassigned(self) -> int:
        """Number of values unassigned / missing."""
        coded = self._coded_values
        return int(numpy.count_nonzero(coded == 0))
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def _esm_source() -> str | Path:
    """Pick the frontend ES module for the widget.

    If ``ANY_SCATTER3D_DEV`` is set to a non-empty value, return the
    development URL from ``ANY_SCATTER3D_DEV_URL`` (default ``DEF_DEV_ESM``).
    Otherwise return the path of the bundled production script.
    """
    dev_mode = os.environ.get("ANY_SCATTER3D_DEV", "")
    if not dev_mode:
        return PROD_ESM
    return os.environ.get("ANY_SCATTER3D_DEV_URL", DEF_DEV_ESM)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _is_missing(value: object) -> bool:
    """Return True if *value* is a scalar missing marker (None, NA, NaN, NaT).

    Inputs that ``pandas.isna`` cannot reduce to a single truth value
    (e.g. multi-element sequences) are reported as not missing.
    """
    # Cheap identity checks first.
    if value is None:
        return True
    if value is pandas.NA:
        return True
    try:
        # pandas.isna covers float NaN, NaT and the pandas NA scalars.
        result = pandas.isna(value)
        return bool(result)
    except Exception:
        return False
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
class Scatter3dWidget(anywidget.AnyWidget):
    """anywidget-based interactive 3D scatter plot coloured by a Category.

    Python holds the coordinates and a :class:`Category` and pushes their
    state into the synced ``*_t`` traitlets below. The frontend module
    (``_esm``) renders them and communicates back through the lasso,
    tooltip and readiness traitlets.
    """

    # Frontend module: the built bundle, or a dev-server URL when
    # ANY_SCATTER3D_DEV is set. Chosen once, when the class is created.
    _esm = _esm_source()

    # xyz coords for the points
    # Packed float32 array of shape (N, 3), row-major.
    # TS interprets as Float32Array with length 3*N.
    xyz_bytes_t = traitlets.Bytes(
        default_value=b"",
        help="Packed float32 Nx3, row-major.",
    ).tag(sync=True)

    # Packed uint16 array of length N.
    # Code 0 means "missing / unassigned".
    # Codes 1..K correspond to labels_t[0..K-1].
    coded_values_t = traitlets.Bytes(
        default_value=b"",
        help="Packed uint16 length N. 0=missing, 1..K correspond to labels_t.",
    ).tag(sync=True)

    # List[str] of length K, stable ordering.
    # labels_t[i] corresponds to code (i+1).
    labels_t = traitlets.List(
        traitlets.Unicode(),
        default_value=[],
        help="Label list (length K), where code = index+1.",
    ).tag(sync=True)

    # List[[r,g,b]] of length K, aligned with labels_t.
    # Each component is float in [0,1].
    colors_t = traitlets.List(
        traitlets.List(traitlets.Float(), minlen=3, maxlen=3),
        default_value=[],
        help="Per-label RGB colors (length K) aligned with labels_t; floats in [0,1].",
    ).tag(sync=True)

    # [r,g,b] used when coded value is 0 or otherwise missing
    missing_color_t = traitlets.List(
        traitlets.Float(),
        # Derived from the module-level MISSING_COLOR instead of duplicating
        # the literal, so the two cannot drift apart.
        default_value=list(MISSING_COLOR),
        minlen=3,
        maxlen=3,
        help="RGB color for missing/unassigned (code 0).",
    ).tag(sync=True)

    # --- lasso round-trip channels ---
    # Dict message TS -> Python describing a committed lasso operation.
    lasso_request_t = traitlets.Dict(
        default_value={},
        help="Frontend -> Python message describing a committed lasso operation.",
    ).tag(sync=True)
    # Packed uint8 bitmask (binary).
    lasso_mask_t = traitlets.Bytes(
        default_value=b"",
        help="Packed uint8 bitmask accompanying a lasso request.",
    ).tag(sync=True)
    # Dict message Python -> TS acknowledging the last request (ok/error).
    lasso_result_t = traitlets.Dict(
        default_value={},
        help="Python -> frontend acknowledgement of the last lasso request (ok/error).",
    ).tag(sync=True)

    point_size_t = traitlets.Float(
        default_value=DEFAULT_POINT_SIZE,
        help="Point size for rendering (three.js PointsMaterial.size).",
    ).tag(sync=True)

    axis_label_size_t = traitlets.Float(
        default_value=DEFAULT_AXIS_LABEL_SIZE,
        help="Size of the axis labels drawn by the frontend.",
    ).tag(sync=True)

    show_axes_t = traitlets.Bool(
        default_value=True,
        help="Whether to draw axis lines (X, Y, Z) from the origin (0,0,0).",
    ).tag(sync=True)

    # Tooltip request/response pair; see _on_tooltip_request for the schema.
    tooltip_request_t = traitlets.Dict(
        default_value={},
        help="Frontend -> Python tooltip request: {kind, request_id, i}.",
    ).tag(sync=True)
    tooltip_response_t = traitlets.Dict(
        default_value={},
        help="Python -> frontend tooltip response: {request_id, status, data|message}.",
    ).tag(sync=True)

    client_ready_t = traitlets.Bool(
        default_value=False,
        help="Set True by the frontend once JS is initialized and can talk to Python.",
    ).tag(sync=True)

    interactive_ready_t = traitlets.Bool(
        default_value=False,
        help="True once frontend has announced readiness; used to gate UI features.",
    ).tag(sync=True)

    interaction_mode_t = traitlets.Unicode(
        default_value="rotate",
        help="Interaction mode: 'rotate' or 'lasso'.",
    ).tag(sync=True)

    # Active category label (string). In lasso mode it must always be non-empty and valid.
    active_category_t = traitlets.Unicode(
        default_value=None,
        allow_none=True,
        help="Active category label (str) or None. In lasso mode it must be a valid label from labels_t.",
    ).tag(sync=True)

    # Legend placement (split into two validated strings; easier to validate than a Dict schema)
    legend_side_t = traitlets.Unicode(
        default_value="right",
        help="Legend side: 'left' or 'right'.",
    ).tag(sync=True)

    legend_dock_t = traitlets.Unicode(
        default_value="top",
        help="Legend dock: 'top' or 'bottom'.",
    ).tag(sync=True)
|
|
491
|
+
|
|
492
|
+
def __init__(
    self,
    xyz: numpy.ndarray,
    category: Category,
    point_ids: Sequence[str]
    | Sequence[int]
    | narwhals.typing.IntoSeriesT
    | None = None,
):
    """Build the widget from coordinates, a Category and optional point ids.

    Args:
        xyz: array of shape (N, 3) with the point coordinates.
        category: Category with exactly N values, used for labels and colours.
        point_ids: optional ids shown in tooltips (one per point). Defaults
            to 1..N.

    Raises:
        ValueError: if the category or point_ids length differs from N, or
            if xyz is not an (N, 3) numpy array.
        TypeError: if point_ids has an unsupported type.
    """
    super().__init__()
    self._category_cb_id: int | None = None

    n_points = xyz.shape[0]

    if category is not None and n_points != category.num_values:
        raise ValueError(
            f"The number of points ({n_points}) should match "
            f"the number of values in the category: {category.num_values}"
        )

    if point_ids is not None and n_points != len(point_ids):
        # Report the point_ids length. The old message quoted the category
        # size and crashed with AttributeError when category was None.
        raise ValueError(
            f"The number of points ({n_points}) should match "
            f"the number of point ids: {len(point_ids)}"
        )

    # Keep a stable callback object so unsubscribe works.
    self._category_cb = self._on_category_changed

    self._xyz = None
    self._category = None
    self.xyz = xyz

    self._set_default_sizes()

    self.category = category

    self.point_ids = self._normalize_point_ids(point_ids)

    # clear tooltip state
    self.tooltip_response_t = {}

    # Enforce initial invariants (rotate default allows empty active category)
    self._ensure_active_category_invariants()
|
|
534
|
+
|
|
535
|
+
@traitlets.validate("interaction_mode_t")
def _validate_interaction_mode_t(self, proposal):
    """Accept only the two supported interaction modes."""
    candidate = proposal["value"]
    allowed_modes = ("rotate", "lasso")
    if candidate in allowed_modes:
        return candidate
    raise traitlets.TraitError("interaction_mode_t must be 'rotate' or 'lasso'")
|
|
541
|
+
|
|
542
|
+
@traitlets.validate("legend_side_t")
def _validate_legend_side_t(self, proposal):
    """Accept only 'left' or 'right' for the legend side."""
    side = proposal["value"]
    allowed_sides = ("left", "right")
    if side in allowed_sides:
        return side
    raise traitlets.TraitError("legend_side_t must be 'left' or 'right'")
|
|
548
|
+
|
|
549
|
+
@traitlets.validate("legend_dock_t")
def _validate_legend_dock_t(self, proposal):
    """Accept only 'top' or 'bottom' for the legend dock."""
    dock = proposal["value"]
    allowed_docks = ("top", "bottom")
    if dock in allowed_docks:
        return dock
    raise traitlets.TraitError("legend_dock_t must be 'top' or 'bottom'")
|
|
555
|
+
|
|
556
|
+
def _ensure_active_category_invariants(self) -> None:
    """Make active_category_t consistent with interaction_mode_t and labels_t.

    Lasso mode: the active category must be one of labels_t. If it is None
    or not in labels_t, it is reset to the first label (and pushed to the
    frontend once it is interactive). With no labels at all, raise.

    Rotate mode: None is fine. A non-None label missing from labels_t is an
    error; no silent fallback.

    Raises:
        RuntimeError: lasso mode with empty labels_t, or rotate mode with an
            active category that is not in labels_t.
    """
    mode = self.interaction_mode_t
    labels = list(self.labels_t or [])

    if mode != "lasso":
        current = self.active_category_t
        if current is not None and current not in labels:
            raise RuntimeError(
                f"active_category_t={current!r} is not present in labels_t"
            )
        return

    if not labels:
        # Lasso mode without categories is a hard error.
        raise RuntimeError("Cannot enter lasso mode: labels_t is empty")

    current = self.active_category_t
    if current is not None and current in labels:
        return

    # Deterministic fallback: the first category label.
    self.active_category_t = labels[0]
    if self.interactive_ready_t:
        self.send_state("active_category_t")
|
|
588
|
+
|
|
589
|
+
@traitlets.observe("interaction_mode_t")
def _on_interaction_mode_t(self, change) -> None:
    """On switching into lasso mode, ensure a valid active category exists."""
    entering_lasso = change.get("new") == "lasso"
    if entering_lasso:
        self._ensure_active_category_invariants()
|
|
593
|
+
|
|
594
|
+
@traitlets.observe("active_category_t")
def _on_active_category_t(self, change) -> None:
    """Refuse clearing the active category while in lasso mode.

    traitlets does not run @validate for None when allow_none=True, so this
    observer enforces the rule instead. It restores the previous value (if
    there was one) and then raises TraitError.
    """
    in_lasso = self.interaction_mode_t == "lasso"
    if not in_lasso or change.get("new") is not None:
        return

    previous = change.get("old")
    if previous is not None:
        # Put the last valid label back so the state stays consistent.
        self.set_trait("active_category_t", previous)
    raise traitlets.TraitError("active_category_t cannot be None in lasso mode")
|
|
604
|
+
|
|
605
|
+
@traitlets.observe("labels_t")
def _on_labels_t(self, change) -> None:
    """Re-check active_category_t whenever the label list changes.

    Lasso mode: an absent or invalid active category is replaced by the
    first label (see _ensure_active_category_invariants).
    Rotate mode: a non-None active category that is no longer among the new
    labels is an error.

    Raises:
        RuntimeError: rotate mode with an active category not in the new labels.
    """
    if self.interaction_mode_t == "lasso":
        self._ensure_active_category_invariants()
        return

    new_labels = change.get("new") or []
    active = self.active_category_t
    if active is not None and active not in new_labels:
        # Previously `raise RuntimeError(...)`, i.e. the message was "Ellipsis".
        raise RuntimeError(
            f"active_category_t={active!r} is not present in labels_t"
        )
|
|
614
|
+
|
|
615
|
+
@traitlets.observe("client_ready_t")
def _on_client_ready_t(self, change) -> None:
    """Latch interactive_ready_t to True the first time the client reports ready.

    The transition is one-way (False -> True) and is explicitly synced.
    """
    if bool(change.get("new")) is not True:
        return
    if self.interactive_ready_t is True:
        return
    self.interactive_ready_t = True
    self.send_state("interactive_ready_t")
|
|
621
|
+
|
|
622
|
+
def _set_default_sizes(self):
    """Derive point and axis-label sizes from the extent of the data.

    Sizes scale with the largest absolute finite coordinate: 1/20 of it for
    points and 1/5 for axis labels. The module defaults are used when there
    are no points, or when no finite non-zero coordinate exists (all zeros,
    NaN or inf). This prevents a zero or NaN size, which would hide the
    points. It must never crash widget construction.
    """
    if self.num_points == 0:
        self.point_size_t = float(DEFAULT_POINT_SIZE)
        self.axis_label_size_t = float(DEFAULT_AXIS_LABEL_SIZE)
        return

    magnitudes = numpy.abs(self.xyz)
    finite = magnitudes[numpy.isfinite(magnitudes)]
    max_abs = float(finite.max()) if finite.size else 0.0

    if max_abs <= 0.0:
        # Degenerate extent: fall back to deterministic defaults.
        self.point_size_t = float(DEFAULT_POINT_SIZE)
        self.axis_label_size_t = float(DEFAULT_AXIS_LABEL_SIZE)
        return

    self.point_size_t = max_abs / 20.0
    self.axis_label_size_t = max_abs / 5.0
|
|
633
|
+
|
|
634
|
+
def _normalize_point_ids(self, point_ids):
    """Return *point_ids* as a tuple with exactly one entry per point.

    Accepted inputs:
      - ``None``: ids default to ``1..N``.
      - a narwhals Series, or a native series accepted by
        ``narwhals.from_native(..., series_only=True)`` (e.g. pandas or
        polars). These were previously rejected although the constructor
        advertises them.
      - a 1-D numpy array.
      - any non-string ``Sequence``.

    Raises:
        TypeError: for strings/bytes or other unsupported inputs.
        ValueError: if the number of ids differs from the number of points.
    """
    type_error = "point_ids must be a Series, a sequence of values, or None"
    num_points = self.num_points

    if point_ids is None:
        ids = tuple(range(1, num_points + 1))
    elif isinstance(point_ids, narwhals.Series):
        ids = tuple(point_ids.to_list())
    elif isinstance(point_ids, (str, bytes)):
        raise TypeError(type_error)
    elif isinstance(point_ids, Sequence):
        ids = tuple(point_ids)
    elif isinstance(point_ids, numpy.ndarray):
        if point_ids.ndim != 1:
            raise TypeError(type_error)
        # tolist() yields plain Python scalars rather than numpy scalars.
        ids = tuple(point_ids.tolist())
    else:
        # Native dataframe-library series are not collections.abc.Sequence
        # instances, so route them through narwhals.
        try:
            series = narwhals.from_native(point_ids, series_only=True)
        except TypeError:
            raise TypeError(type_error) from None
        ids = tuple(series.to_list())

    if len(ids) != num_points:
        raise ValueError("point_ids length must match number of points")

    return ids
|
|
654
|
+
|
|
655
|
+
def _category_label_for_index(self, idx: int) -> str | None:
    """Return the label of point *idx* as a string, or None if unassigned.

    Also returns None when no category is attached.
    """
    category = self._category
    if category is None:
        return None
    code = int(category.coded_values[idx])
    # Code 0 (or anything non-positive) means unassigned; codes 1..K index
    # the 0-based label_list at code - 1.
    return None if code <= 0 else str(category.label_list[code - 1])
|
|
664
|
+
|
|
665
|
+
@traitlets.validate("active_category_t")
def _validate_active_category_t(self, proposal):
    """Validate a proposed active category label.

    Rules:
      - None: allowed only in rotate mode. traitlets does not run
        cross-validators for None when allow_none=True, so in practice this
        case is enforced by the active_category_t observer.
      - The value must be a non-empty string.
      - Lasso mode: labels_t must be non-empty and contain the value.
      - Rotate mode: when labels_t is non-empty, it must contain the value.

    Raises:
        traitlets.TraitError: when any rule is violated.
    """
    v = proposal["value"]
    mode = getattr(self, "interaction_mode_t", "rotate")

    if v is None:
        if mode == "lasso":
            raise traitlets.TraitError(
                "active_category_t cannot be None in lasso mode"
            )
        return None

    if not isinstance(v, str):
        raise traitlets.TraitError("active_category_t must be a string or None")

    if v == "":
        # The empty string was a legacy "no selection" sentinel; reject it.
        raise traitlets.TraitError(
            "active_category_t must be None (rotate) or a valid label string (rotate/lasso); empty string is not allowed"
        )

    labels = list(self.labels_t or [])

    if mode == "lasso" and not labels:
        raise traitlets.TraitError(
            "Cannot set active_category_t: labels_t is empty"
        )

    # Membership is required in both modes whenever labels exist (in lasso
    # mode they always exist at this point).
    if labels and v not in labels:
        raise traitlets.TraitError(
            f"active_category_t={v!r} is not present in labels_t"
        )
    return v
|
|
708
|
+
|
|
709
|
+
@traitlets.observe("tooltip_request_t")
def _on_tooltip_request(self, change) -> None:
    """Answer a frontend tooltip request with details of one point.

    Requests that are not dicts, or whose ``"kind"`` is not ``"tooltip"``,
    are ignored. Otherwise ``"request_id"`` and the point index ``"i"`` are
    read, and a reply is written to tooltip_response_t:

    - ``{"request_id", "status": "ok", "data": {"idx", "id", "category"}}``
    - ``{"request_id", "status": "error", "message"}`` on any failure.

    The reply is always force-synced with send_state.
    """
    req = change["new"] or {}
    if not isinstance(req, dict) or req.get("kind") != "tooltip":
        return

    # Parse the id inside the try: a malformed request_id used to raise out
    # of the observer before any response was sent. The fallback id is 0.
    request_id = 0
    try:
        request_id = int(req.get("request_id", 0) or 0)
        i = int(req.get("i", None))
        if i < 0 or i >= self.num_points:
            raise IndexError(f"point index out of range: {i}")

        data = {
            "idx": i,
            "id": str(self.point_ids[i]),
            "category": self._category_label_for_index(i),
        }

        self.tooltip_response_t = {
            "request_id": request_id,
            "status": "ok",
            "data": data,
        }
    except Exception as e:
        # Frontend boundary: report the failure back instead of raising.
        self.tooltip_response_t = {
            "request_id": request_id,
            "status": "error",
            "message": str(e),
        }

    # Force comm sync (important for anywidget)
    self.send_state("tooltip_response_t")
|
|
745
|
+
|
|
746
|
+
def _on_category_changed(self, category: Category, event: str) -> None:
    """Category subscriber: re-sync the transport traitlets after a mutation.

    *event* is not inspected; every notification triggers a full sync.
    Notifications from a Category other than the one currently attached
    (e.g. a replaced one) are ignored.
    """
    if category is self._category:
        self._sync_traitlets_from_category()
|
|
754
|
+
|
|
755
|
+
@staticmethod
|
|
756
|
+
def _pack_xyz_float32_c(xyz: numpy.ndarray) -> tuple[numpy.ndarray, bytes]:
|
|
757
|
+
"""
|
|
758
|
+
Return (xyz_float32_c, packed_bytes).
|
|
759
|
+
- xyz_float32_c: float32, C-contiguous, shape (N,3)
|
|
760
|
+
- packed_bytes: xyz_float32_c.tobytes(order="C")
|
|
761
|
+
"""
|
|
762
|
+
if not isinstance(xyz, numpy.ndarray):
|
|
763
|
+
raise ValueError("xyz should be a numpy array")
|
|
764
|
+
|
|
765
|
+
if xyz.ndim != 2 or xyz.shape[1] != 3:
|
|
766
|
+
raise ValueError("xyz should have shape (N, 3)")
|
|
767
|
+
|
|
768
|
+
# Convert dtype to float32 (TS expects Float32Array)
|
|
769
|
+
# Ensure row-major contiguous layout for stable tobytes.
|
|
770
|
+
xyz_f32 = numpy.asarray(xyz, dtype=numpy.float32, order="C")
|
|
771
|
+
if not xyz_f32.flags["C_CONTIGUOUS"]:
|
|
772
|
+
xyz_f32 = numpy.ascontiguousarray(xyz_f32)
|
|
773
|
+
|
|
774
|
+
# Remap (x,y,z) -> (x,z,y) so that "z" in user data becomes "up" (Y) in Three.js
|
|
775
|
+
# avoid mutating caller's array if it was already float32 C
|
|
776
|
+
# xyz_f32 = xyz_f32.copy()
|
|
777
|
+
xyz_f32[:, [1, 2]] = xyz_f32[:, [2, 1]]
|
|
778
|
+
|
|
779
|
+
return xyz_f32, xyz_f32.tobytes(order="C")
|
|
780
|
+
|
|
781
|
+
def _get_xyz(self) -> numpy.ndarray:
    """
    Return a copy of the point coordinates in the user's (x, y, z) order.

    The stored array keeps the columns as (x, z, y); this undoes that swap on
    a freshly allocated array, so callers cannot mutate the stored data.

    Raises:
        RuntimeError: if no coordinates have been assigned yet.
    """
    stored = self._xyz
    if stored is None:
        raise RuntimeError("xyz has not been set")
    # Advanced indexing allocates a new array; make its layout C-ordered.
    return numpy.ascontiguousarray(stored[:, [0, 2, 1]])

def _set_xyz(self, xyz: numpy.ndarray) -> None:
    """
    Validate and convert new point coordinates, then publish them.

    The input goes through ``_pack_xyz_float32_c``. The converted array is
    stored in ``self._xyz``, and its raw bytes are assigned to the synced
    ``xyz_bytes_t`` traitlet.

    Raises:
        ValueError: if xyz is malformed, or if a category is attached and the
            number of points differs from its number of values.
    """
    packed, payload = self._pack_xyz_float32_c(xyz)
    n_new = packed.shape[0]

    # With a category attached, the point count must stay in step with it.
    if self._category is not None:
        expected = self.category.num_values
        if n_new != expected:
            raise ValueError(
                f"The number of points ({n_new}) should match "
                f"the number of values in the category: {expected}"
            )

    self._xyz = packed
    self.xyz_bytes_t = payload

xyz = property(_get_xyz, _set_xyz)
|
|
802
|
+
|
|
803
|
+
@staticmethod
def _pack_u16_c(arr: numpy.ndarray) -> bytes:
    """
    Serialize category codes as contiguous little-endian uint16 bytes.

    Args:
        arr: array-like of codes. Integer inputs must lie in [0, 65535].

    Returns:
        The values encoded as "<u2" in C order (2 bytes per value).

    Raises:
        ValueError: if an integer input holds a value outside the uint16
            range. Such values used to be silently wrapped around, which
            would assign points to the wrong label.
    """
    values = numpy.asarray(arr)
    if values.dtype.kind in "iu" and values.size:
        lo, hi = values.min(), values.max()
        if lo < 0 or hi > 65535:
            raise ValueError(
                f"Codes must fit in uint16 (0..65535), got range [{lo}, {hi}]"
            )
    # Explicit little-endian so the payload does not depend on the kernel
    # host's byte order (presumably decoded as a Uint16Array on the frontend).
    packed = numpy.ascontiguousarray(values, dtype="<u2")
    return packed.tobytes(order="C")
|
|
809
|
+
|
|
810
|
+
def _sync_traitlets_from_category(self) -> None:
    """
    Push the Category state into the synced transport traitlets.

    Sets ``labels_t``, ``coded_values_t``, ``colors_t`` and
    ``missing_color_t`` from ``self._category``, then calls
    ``self._ensure_active_category_invariants()``.

    Every value is computed and validated before any traitlet is assigned,
    so a failure leaves the previously synced state untouched. The
    assignments run inside ``hold_sync()`` so the frontend gets them as one
    batched update rather than one message per traitlet, and never sees new
    labels paired with stale codes.

    Assumes ``self._xyz`` is set (``num_points`` raises otherwise).

    Raises:
        RuntimeError: if no category is set, or if its number of values
            differs from the number of points.
    """
    cat = self._category
    if cat is None:
        raise RuntimeError("The category should be set")

    label_list = cat.label_list

    # labels_t must be JSON-friendly; enforce str
    labels = [str(lbl) for lbl in label_list]

    # coded values: uint16 bytes, one per point
    coded = cat.coded_values
    if coded.shape[0] != self.num_points:
        raise RuntimeError(
            f"Category has {coded.shape[0]} values but xyz has {self.num_points} points"
        )
    coded_bytes = self._pack_u16_c(coded)

    # Category stores its palette keyed by the original (unstringified)
    # labels; rebuild colors in label_list order so colors[i] pairs with
    # labels[i]. Both lists come from label_list, so their lengths match.
    palette = cat.color_palette  # label -> (r, g, b)
    colors = [list(map(float, palette[lbl])) for lbl in label_list]

    missing_color = list(map(float, cat.missing_color))

    with self.hold_sync():
        self.labels_t = labels
        self.coded_values_t = coded_bytes
        self.colors_t = colors
        self.missing_color_t = missing_color
        self._ensure_active_category_invariants()
|
|
846
|
+
|
|
847
|
+
def _get_category(self):
    """Return the Category currently attached to the widget (or None)."""
    return self._category

def _set_category(self, category: "Category") -> None:
    """
    Attach ``category`` to the widget and publish its state.

    When coordinates are set, the category's number of values must equal
    the number of points. The widget unsubscribes from the previous
    category, subscribes ``_on_category_changed`` to the new one, and syncs
    the transport traitlets.

    If the sync raises, the new subscription is dropped and the previous
    category (and its subscription, if it had one) is restored before the
    error propagates. The widget is therefore never left attached to a
    category whose state was never published.

    Raises:
        ValueError: if the category length differs from the number of points.
    """
    if self._xyz is not None and category.num_values != self.num_points:
        raise ValueError(
            f"The number of values in the category ({category.num_values}) "
            f"should match the number of points {self.num_points}"
        )

    previous = self._category
    previous_cb_id = self._category_cb_id
    if previous is not None and previous_cb_id is not None:
        previous.unsubscribe(previous_cb_id)

    self._category = category
    # Subscribe to new category
    self._category_cb_id = category.subscribe(self._on_category_changed)
    try:
        self._sync_traitlets_from_category()
    except Exception:
        # Roll back to the previously attached category.
        category.unsubscribe(self._category_cb_id)
        self._category = previous
        self._category_cb_id = None
        if previous is not None and previous_cb_id is not None:
            self._category_cb_id = previous.subscribe(self._on_category_changed)
        raise

category = property(_get_category, _set_category)
|
|
865
|
+
|
|
866
|
+
@property
def num_points(self):
    """
    Number of points (rows) in the stored coordinates.

    Reads the shape of the stored array directly instead of going through
    the ``xyz`` property, which would copy and column-swap the entire (N, 3)
    array just to obtain its length.

    Raises:
        RuntimeError: if no coordinates have been assigned yet.
    """
    if self._xyz is None:
        raise RuntimeError("xyz has not been set")
    return self._xyz.shape[0]
|
|
869
|
+
|
|
870
|
+
def close(self):
    """
    Detach from the attached Category, then close the widget.

    Unsubscribing means the Category is no longer asked to call back into
    this widget, and the widget stops keeping that registration alive.
    ``super().close()`` then runs the normal widget teardown.
    """
    cb_id = self._category_cb_id
    if cb_id is not None and self._category is not None:
        self._category.unsubscribe(cb_id)
        self._category_cb_id = None
    super().close()
|
|
876
|
+
|
|
877
|
+
def _label_to_code_map(self) -> dict[str, int]:
    """
    Map each synced label string to its category code.

    Codes are 1-based positions in ``labels_t``; code 0 is not assigned to
    any label here. If a label string repeats, its last position wins.
    """
    return {label: code for code, label in enumerate(self.labels_t, start=1)}
|
|
880
|
+
|
|
881
|
+
def _unpack_mask(self, mask_payload) -> numpy.ndarray:
    """
    Decode a packed lasso bitmask into one boolean per point.

    Bits are read most-significant-bit first within each byte
    (``bitorder="big"``). Only the first ``num_points`` bits are kept, so
    padding bits and any bytes past the required length are ignored.

    Raises:
        ValueError: if the payload is not bytes-like or holds fewer than
            ceil(num_points / 8) bytes.
    """
    n = self.num_points
    required, leftover_bits = divmod(n, 8)
    if leftover_bits:
        required += 1

    if not isinstance(mask_payload, (bytes, bytearray, memoryview)):
        raise ValueError(
            f"lasso_mask_t must be bytes-like, got {type(mask_payload)}"
        )
    raw = bytes(mask_payload)

    available = len(raw)
    if available < required:
        raise ValueError(
            f"lasso_mask_t too short: got {available} bytes, need {required} for N={n}"
        )

    packed = numpy.frombuffer(raw, dtype=numpy.uint8, count=required)
    return numpy.unpackbits(packed, bitorder="big")[:n].astype(bool, copy=False)
|
|
900
|
+
|
|
901
|
+
def _apply_lasso_mask_edit(self, op: str, code: int, mask: numpy.ndarray) -> int:
    """
    Apply an add/remove lasso edit to the attached category's coded values.

    Args:
        op: "add" assigns ``code`` to every masked point; "remove" resets to 0
            only those masked points whose current code equals ``code``.
        code: category code. Label ``label_list[i]`` has code ``i + 1``, and
            0 is reserved for missing/unassigned.
        mask: boolean array of shape (num_points,).

    Returns:
        Number of points whose code actually changed. When that is 0, the
        Category is left untouched: no set_coded_values call, and hence no
        change notification and no needless resync.

    Raises:
        RuntimeError: if no category is set.
        ValueError: on a malformed mask, an unknown op, a code outside
            0..len(label_list), or an attempt to add code 0.
    """
    category = self._category
    if category is None:
        raise RuntimeError("No category set")

    if mask.dtype != numpy.bool_ or mask.shape != (self.num_points,):
        raise ValueError("Internal error: mask must be bool with shape (N,)")

    if code < 0 or code > 65535:
        raise ValueError(f"Invalid code {code} (must fit uint16)")
    if code == 0 and op == "add":
        raise ValueError("Cannot add code 0 (reserved for missing/unassigned)")

    # Codes above the label count do not correspond to any label; accepting
    # them would assign points to a nonexistent category.
    label_list = category.label_list
    if code > len(label_list):
        raise ValueError(
            f"Unknown code {code}: the category only has {len(label_list)} labels"
        )

    target = numpy.uint16(code)
    new = category.coded_values.copy()

    if op == "add":
        changed = int(numpy.count_nonzero(new[mask] != target))
        new[mask] = target
    elif op == "remove":
        # Only remove points currently in that label
        to_zero = mask & (new == target)
        changed = int(numpy.count_nonzero(to_zero))
        new[to_zero] = numpy.uint16(0)
    else:
        raise ValueError(f"Unknown op: {op!r}")

    if changed == 0:
        return 0

    # Update Category (will notify; widget callback syncs coded_values_t etc.)
    category.set_coded_values(
        coded_values=new,
        label_list=label_list,
        skip_copying_array=True,
    )
    return changed
|
|
938
|
+
|
|
939
|
+
@traitlets.observe("lasso_request_t")
def _on_lasso_request_t(self, change) -> None:
    """
    Handle a lasso commit request set on ``lasso_request_t`` by the frontend.

    Request keys read here: ``request_id``, ``kind`` (must be
    "lasso_commit"), ``op`` ("add" or "remove"), and either ``code`` or
    ``label``. The selection itself is read from the ``lasso_mask_t`` bytes
    traitlet.

    The outcome is published on ``lasso_result_t``, either as
    ``{"request_id", "status": "ok", "num_selected", "num_changed"}`` or as
    ``{"request_id", "status": "error", "message"}``. Errors are reported
    rather than raised.
    """
    req = change.get("new", {})
    if not req:
        return

    request_id = req.get("request_id")
    res: dict[str, object] = {"request_id": request_id}

    try:
        if req.get("kind") != "lasso_commit":
            raise ValueError(f"Unsupported kind: {req.get('kind')!r}")

        op = req.get("op")
        if op not in ("add", "remove"):
            raise ValueError(f"Invalid op: {op!r}")

        # resolve code from either explicit code or label
        if "code" in req and req["code"] is not None:
            code = int(req["code"])
        else:
            label = req.get("label")
            if label is None:
                raise ValueError("Missing field: label (or code)")
            label_s = str(label)
            m = self._label_to_code_map()
            if label_s not in m:
                raise ValueError(f"Unknown label: {label_s!r}")
            code = m[label_s]

        # unpack mask from bytes traitlet
        if not isinstance(self.lasso_mask_t, (bytes, bytearray, memoryview)):
            raise RuntimeError(
                f"Internal error: lasso_mask_t must be bytes, got {type(self.lasso_mask_t)}"
            )
        mask = self._unpack_mask(self.lasso_mask_t)
        num_selected = int(numpy.sum(mask))

        changed = self._apply_lasso_mask_edit(op=op, code=code, mask=mask)

        res.update(
            {
                "status": "ok",
                "num_selected": num_selected,
                "num_changed": changed,
            }
        )
    except Exception as e:
        # Report any failure to the frontend instead of raising inside a
        # traitlets observer.
        res.update({"status": "error", "message": str(e)})

    previous = self.lasso_result_t
    self.lasso_result_t = res
    # traitlets only notifies (and so only syncs) when the assigned value
    # differs from the current one. If this result equals the previous one,
    # e.g. because request_id is missing or reused, the frontend would never
    # get a reply, so force the sync the same way the tooltip handler does.
    if res == previous:
        self.send_state("lasso_result_t")
|
|
990
|
+
|
|
991
|
+
def _get_point_size(self) -> float:
    """Return the rendered point size as a float."""
    size = self.point_size_t
    return float(size)

def _set_point_size(self, value: float) -> None:
    """
    Set the rendered point size.

    Raises:
        ValueError: if ``value`` is not finite or not strictly positive.
    """
    size = float(value)
    if numpy.isfinite(size) and size > 0:
        self.point_size_t = size
        return
    raise ValueError("point_size must be a finite positive number")

point_size = property(_get_point_size, _set_point_size)
|
|
1001
|
+
|
|
1002
|
+
def _get_axis_label_size(self) -> float:
    """Return the axis label size as a float."""
    size = self.axis_label_size_t
    return float(size)

def _set_axis_label_size(self, value: float) -> None:
    """
    Set the axis label size.

    Raises:
        ValueError: if ``value`` is not finite or not strictly positive.
    """
    size = float(value)
    if numpy.isfinite(size) and size > 0:
        self.axis_label_size_t = size
        return
    raise ValueError("axes label size must be a finite positive number")

axis_label_size = property(_get_axis_label_size, _set_axis_label_size)
|