lazylabel-gui 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lazylabel/main.py ADDED
@@ -0,0 +1,864 @@
1
+ import sys
2
+ import os
3
+ import numpy as np
4
+ import qdarktheme
5
+ import cv2
6
+ from PyQt6.QtWidgets import (
7
+ QApplication,
8
+ QMainWindow,
9
+ QWidget,
10
+ QHBoxLayout,
11
+ QFileDialog,
12
+ QGraphicsItem,
13
+ QGraphicsEllipseItem,
14
+ QGraphicsLineItem,
15
+ QTableWidgetItem,
16
+ QGraphicsPolygonItem,
17
+ QTableWidgetSelectionRange,
18
+ )
19
+ from PyQt6.QtGui import QPixmap, QColor, QPen, QBrush, QPolygonF, QIcon
20
+ from PyQt6.QtCore import Qt, QPointF, QTimer
21
+
22
+ # Relative imports for package structure
23
+ from .photo_viewer import PhotoViewer
24
+ from .sam_model import SamModel
25
+ from .utils import mask_to_pixmap
26
+ from .controls import ControlPanel, RightPanel
27
+ from .custom_file_system_model import CustomFileSystemModel
28
+ from .editable_vertex import EditableVertexItem
29
+ from .hoverable_polygon_item import HoverablePolygonItem
30
+ from .numeric_table_widget_item import NumericTableWidgetItem
31
+
32
+
33
class MainWindow(QMainWindow):
    """Main LazyLabel window: control panel, image viewer and segment panel.

    Each entry of ``self.segments`` is a dict with the keys ``"type"``
    (``"Polygon"``, ``"SAM"`` or ``"Loaded"``), ``"vertices"`` (a list of
    QPointF for polygons, otherwise None), ``"mask"`` (a mask array for
    SAM/loaded segments, otherwise None) and ``"class_id"``.
    """

    def __init__(self, sam_model):
        """Build the window around an already-constructed ``SamModel``."""
        super().__init__()
        self.setWindowTitle("LazyLabel by DNC")

        icon_path = os.path.join(
            os.path.dirname(__file__), "demo_pictures", "logo2.png"
        )
        if os.path.exists(icon_path):
            self.setWindowIcon(QIcon(icon_path))

        self.setGeometry(50, 50, 1600, 900)

        # The SamModel instance is created by the caller and passed in.
        self.sam_model = sam_model
        self.mode = "sam_points"
        self.previous_mode = "sam_points"
        self.current_image_path = None
        self.current_file_index = None
        self.next_class_id = 0

        # SAM prompt state: scene markers plus the [x, y] prompt lists.
        self.point_items, self.positive_points, self.negative_points = [], [], []
        # Polarity of each marker in point_items (parallel list) so undo can
        # pop the correct prompt list without reconstructing coordinates.
        self._point_is_positive = []
        self.polygon_points, self.polygon_preview_items = [], []
        self.rubber_band_line = None
        # Scene item showing the current SAM prediction, or None.
        self.preview_mask_item = None

        self.segments, self.segment_items, self.highlight_items = [], {}, []
        self.is_dragging_polygon, self.drag_start_pos, self.drag_initial_vertices = (
            False,
            None,
            {},
        )

        self.control_panel = ControlPanel()
        self.right_panel = RightPanel()
        self.viewer = PhotoViewer(self)
        self.viewer.setMouseTracking(True)
        self.file_model = CustomFileSystemModel()
        self.right_panel.file_tree.setModel(self.file_model)
        self.right_panel.file_tree.setColumnWidth(0, 200)

        main_layout = QHBoxLayout()
        main_layout.addWidget(self.control_panel)
        main_layout.addWidget(self.viewer, 1)
        main_layout.addWidget(self.right_panel)
        central_widget = QWidget()
        central_widget.setLayout(main_layout)
        self.setCentralWidget(central_widget)

        self.control_panel.device_label.setText(
            f"Device: {str(self.sam_model.device).upper()}"
        )
        self.setup_connections()
        self.set_sam_mode()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _unique_class_ids(self):
        """Return the sorted, distinct, non-None class ids of all segments."""
        return sorted(
            {
                seg.get("class_id")
                for seg in self.segments
                if seg.get("class_id") is not None
            }
        )

    def _class_hues(self):
        """Map each class id (in sorted order) to an evenly spaced HSV hue."""
        unique_ids = self._unique_class_ids()
        num_classes = len(unique_ids) or 1
        return {
            cid: int(idx * 360 / num_classes) % 360
            for idx, cid in enumerate(unique_ids)
        }

    def _segment_mask(self, seg):
        """Return a segment's mask, rasterizing polygons; may be None."""
        if seg["type"] == "Polygon":
            return self.rasterize_polygon(seg["vertices"])
        return seg.get("mask")

    # ------------------------------------------------------------------
    # Wiring and modes
    # ------------------------------------------------------------------
    def setup_connections(self):
        """Hook the scene mouse handlers and connect the panel signals."""
        # Keep the scene's own handlers so they can still be chained to.
        self._original_mouse_press = self.viewer.scene().mousePressEvent
        self._original_mouse_move = self.viewer.scene().mouseMoveEvent
        self._original_mouse_release = self.viewer.scene().mouseReleaseEvent

        self.viewer.scene().mousePressEvent = self.scene_mouse_press
        self.viewer.scene().mouseMoveEvent = self.scene_mouse_move
        self.viewer.scene().mouseReleaseEvent = self.scene_mouse_release

        self.right_panel.btn_open_folder.clicked.connect(self.open_folder_dialog)
        self.right_panel.file_tree.doubleClicked.connect(self.load_selected_image)
        self.right_panel.btn_merge_selection.clicked.connect(
            self.assign_selected_to_class
        )
        self.right_panel.btn_delete_selection.clicked.connect(
            self.delete_selected_segments
        )
        self.right_panel.segment_table.itemSelectionChanged.connect(
            self.highlight_selected_segments
        )
        self.right_panel.segment_table.itemChanged.connect(self.handle_class_id_change)
        self.right_panel.btn_reassign_classes.clicked.connect(self.reassign_class_ids)
        self.right_panel.class_filter_combo.currentIndexChanged.connect(
            self.update_segment_table
        )

        self.control_panel.btn_sam_mode.clicked.connect(self.set_sam_mode)
        self.control_panel.btn_polygon_mode.clicked.connect(self.set_polygon_mode)
        self.control_panel.btn_selection_mode.clicked.connect(
            self.toggle_selection_mode
        )
        self.control_panel.btn_clear_points.clicked.connect(self.clear_all_points)

    def set_mode(self, mode_name, is_toggle=False):
        """Switch the interaction mode.

        Args:
            mode_name: one of "sam_points", "polygon", "selection", "pan",
                "edit".
            is_toggle: when True the caller manages ``previous_mode``.
        """
        leaving_edit = self.mode == "edit" and mode_name != "edit"
        if not is_toggle and self.mode not in ["pan", "selection", "edit"]:
            self.previous_mode = self.mode
        self.mode = mode_name
        self.control_panel.mode_label.setText(
            f"Mode: {mode_name.replace('_', ' ').title()}"
        )
        self.clear_all_points()
        self.viewer.setDragMode(
            self.viewer.DragMode.ScrollHandDrag
            if self.mode == "pan"
            else self.viewer.DragMode.NoDrag
        )
        if leaving_edit:
            # Redraw only after self.mode changed; display_all_segments
            # creates editable vertex handles while the mode is "edit".
            self.display_all_segments()

    def set_sam_mode(self):
        """Enter SAM point-prompt mode."""
        self.set_mode("sam_points")

    def set_polygon_mode(self):
        """Enter manual polygon drawing mode."""
        self.set_mode("polygon")

    def toggle_mode(self, new_mode):
        """Toggle ``new_mode`` on, or return to the previous drawing mode."""
        if self.mode == new_mode:
            self.set_mode(self.previous_mode, is_toggle=True)
        else:
            if self.mode not in ["pan", "selection", "edit"]:
                self.previous_mode = self.mode
            self.set_mode(new_mode, is_toggle=True)

    def toggle_pan_mode(self):
        """Toggle hand-drag panning."""
        self.toggle_mode("pan")

    def toggle_selection_mode(self):
        """Toggle click-to-select segment mode."""
        self.toggle_mode("selection")

    def toggle_edit_mode(self):
        """Toggle polygon edit mode (only from selection with a polygon selected)."""
        selected_indices = self.get_selected_segment_indices()
        can_edit = any(
            self.segments[i].get("type") == "Polygon" for i in selected_indices
        )
        if self.mode == "edit":
            # set_mode redraws the segments when edit mode is left.
            self.set_mode("selection", is_toggle=True)
        elif self.mode == "selection" and can_edit:
            self.set_mode("edit", is_toggle=True)
            # Draw the editable vertex handles for the selected polygons.
            self.display_all_segments()

    # ------------------------------------------------------------------
    # Files and images
    # ------------------------------------------------------------------
    def open_folder_dialog(self):
        """Ask for an image folder and show it in the file tree."""
        folder_path = QFileDialog.getExistingDirectory(self, "Select Image Folder")
        if folder_path:
            self.right_panel.file_tree.setRootIndex(
                self.file_model.setRootPath(folder_path)
            )
        self.viewer.setFocus()

    def load_selected_image(self, index):
        """Load the image at a file-tree index and any saved ``.npz`` mask."""
        if not index.isValid():
            return

        self.current_file_index = index
        path = self.file_model.filePath(index)

        if not (
            os.path.isfile(path) and path.lower().endswith((".png", ".jpg", ".jpeg"))
        ):
            return
        self.current_image_path = path
        pixmap = QPixmap(self.current_image_path)
        if pixmap.isNull():
            return
        self.reset_state()
        self.viewer.set_photo(pixmap)
        self.sam_model.set_image(self.current_image_path)
        self.load_existing_mask()
        self.viewer.setFocus()

    def reset_state(self):
        """Drop all segments and every scene item except the image itself."""
        self.clear_all_points()
        self.segments.clear()
        self.next_class_id = 0
        self.update_all_lists()
        items_to_remove = [
            item
            for item in self.viewer.scene().items()
            if item is not self.viewer._pixmap_item
        ]
        for item in items_to_remove:
            self.viewer.scene().removeItem(item)
        self.segment_items.clear()
        self.highlight_items.clear()

    # ------------------------------------------------------------------
    # Input handling
    # ------------------------------------------------------------------
    def keyPressEvent(self, event):
        """Keyboard shortcuts: WASD scroll, 1/2/E/Q/R modes, editing keys."""
        key, mods = event.key(), event.modifiers()
        if event.isAutoRepeat():
            return
        if key == Qt.Key.Key_W:
            self.viewer.verticalScrollBar().setValue(
                self.viewer.verticalScrollBar().value()
                - int(self.viewer.height() * 0.1)
            )
        elif key == Qt.Key.Key_S and not mods:
            self.viewer.verticalScrollBar().setValue(
                self.viewer.verticalScrollBar().value()
                + int(self.viewer.height() * 0.1)
            )
        elif key == Qt.Key.Key_A and not (mods & Qt.KeyboardModifier.ControlModifier):
            self.viewer.horizontalScrollBar().setValue(
                self.viewer.horizontalScrollBar().value()
                - int(self.viewer.width() * 0.1)
            )
        elif key == Qt.Key.Key_D:
            self.viewer.horizontalScrollBar().setValue(
                self.viewer.horizontalScrollBar().value()
                + int(self.viewer.width() * 0.1)
            )
        elif key == Qt.Key.Key_1:
            self.set_sam_mode()
        elif key == Qt.Key.Key_2:
            self.set_polygon_mode()
        elif key == Qt.Key.Key_E:
            self.toggle_selection_mode()
        elif key == Qt.Key.Key_Q:
            self.toggle_pan_mode()
        elif key == Qt.Key.Key_R:
            self.toggle_edit_mode()
        elif key == Qt.Key.Key_C:
            self.clear_all_points()
        elif key == Qt.Key.Key_V or key == Qt.Key.Key_Backspace:
            self.delete_selected_segments()
        elif key == Qt.Key.Key_M:
            self.assign_selected_to_class()
        elif key == Qt.Key.Key_Z and mods == Qt.KeyboardModifier.ControlModifier:
            self.undo_last_action()
        elif key == Qt.Key.Key_A and mods == Qt.KeyboardModifier.ControlModifier:
            self.right_panel.segment_table.selectAll()
        elif key == Qt.Key.Key_Space:
            self.save_current_segment()
        elif key == Qt.Key.Key_Return or key == Qt.Key.Key_Enter:
            self.save_output_to_npz()

    def scene_mouse_press(self, event):
        """Dispatch a scene click to the active mode (items get first chance)."""
        self._original_mouse_press(event)
        if event.isAccepted():
            # An item (e.g. an editable vertex) consumed the press.
            return
        pos = event.scenePos()
        if (
            self.viewer._pixmap_item.pixmap().isNull()
            or not self.viewer._pixmap_item.pixmap().rect().contains(pos.toPoint())
        ):
            return
        if self.mode == "sam_points":
            if event.button() == Qt.MouseButton.LeftButton:
                self.add_point(pos, positive=True)
                self.update_segmentation()
            elif event.button() == Qt.MouseButton.RightButton:
                self.add_point(pos, positive=False)
                self.update_segmentation()
        elif self.mode == "polygon":
            if event.button() == Qt.MouseButton.LeftButton:
                self.handle_polygon_click(pos)
        elif self.mode == "selection":
            if event.button() == Qt.MouseButton.LeftButton:
                self.handle_segment_selection_click(pos)
        elif self.mode == "edit":
            # Start dragging every selected polygon as a whole.
            self.drag_start_pos = pos
            self.is_dragging_polygon = True
            selected_indices = self.get_selected_segment_indices()
            self.drag_initial_vertices = {
                i: list(self.segments[i]["vertices"])
                for i in selected_indices
                if self.segments[i].get("type") == "Polygon"
            }

    def scene_mouse_move(self, event):
        """Move dragged polygons or update the polygon rubber-band line."""
        pos = event.scenePos()
        if self.mode == "edit" and self.is_dragging_polygon:
            delta = pos - self.drag_start_pos
            for i, initial_verts in self.drag_initial_vertices.items():
                self.segments[i]["vertices"] = [
                    QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts
                ]
                self.update_polygon_visuals(i)
        elif self.mode == "polygon" and self.polygon_points:
            if self.rubber_band_line is None:
                self.rubber_band_line = QGraphicsLineItem()
                self.rubber_band_line.setPen(
                    QPen(Qt.GlobalColor.white, 2, Qt.PenStyle.DotLine)
                )
                self.viewer.scene().addItem(self.rubber_band_line)
            self.rubber_band_line.setLine(
                self.polygon_points[-1].x(),
                self.polygon_points[-1].y(),
                pos.x(),
                pos.y(),
            )
            self.rubber_band_line.show()
        else:
            self._original_mouse_move(event)

    def scene_mouse_release(self, event):
        """Finish a polygon drag and redraw so vertices/highlights follow."""
        finished_drag = self.mode == "edit" and self.is_dragging_polygon
        if finished_drag:
            self.is_dragging_polygon = False
            self.drag_initial_vertices.clear()
        self._original_mouse_release(event)
        if finished_drag:
            # During the drag only the polygon shape moved; vertex dots, edit
            # handles and the selection highlight are still at the old spot.
            self.display_all_segments()

    def undo_last_action(self):
        """Undo the last polygon vertex or the last SAM prompt point."""
        if self.mode == "polygon" and self.polygon_points:
            self.polygon_points.pop()
            # Remove the undone vertex's marker dot: the most recent ellipse.
            # Lines and the preview polygon are rebuilt by draw_polygon_preview.
            for idx in range(len(self.polygon_preview_items) - 1, -1, -1):
                if isinstance(self.polygon_preview_items[idx], QGraphicsEllipseItem):
                    self.viewer.scene().removeItem(self.polygon_preview_items.pop(idx))
                    break
            self.draw_polygon_preview()
        elif self.mode == "sam_points" and self.point_items:
            item_to_remove = self.point_items.pop()
            if self._point_is_positive:
                # Prompts are undone LIFO, so the newest point of this
                # polarity is the last element of its list.
                positive = self._point_is_positive.pop()
                point_list = (
                    self.positive_points if positive else self.negative_points
                )
                if point_list:
                    point_list.pop()
            self.viewer.scene().removeItem(item_to_remove)
            self.update_segmentation()

    # ------------------------------------------------------------------
    # Segment creation and selection
    # ------------------------------------------------------------------
    def finalize_polygon(self):
        """Store the in-progress polygon (3+ vertices) as a new segment."""
        if len(self.polygon_points) < 3:
            return
        if self.rubber_band_line:
            self.viewer.scene().removeItem(self.rubber_band_line)
            self.rubber_band_line = None
        self.segments.append(
            {
                "vertices": list(self.polygon_points),
                "type": "Polygon",
                "mask": None,
                "class_id": self.next_class_id,
            }
        )
        self.next_class_id += 1
        self.polygon_points.clear()
        for item in self.polygon_preview_items:
            self.viewer.scene().removeItem(item)
        self.polygon_preview_items.clear()
        self.update_all_lists()

    def handle_segment_selection_click(self, pos):
        """Toggle table selection of the topmost listed segment under ``pos``."""
        x, y = int(pos.x()), int(pos.y())
        table = self.right_panel.segment_table
        # Newest segments are drawn last, so test them first.
        for i in reversed(range(len(self.segments))):
            mask = self._segment_mask(self.segments[i])
            if mask is None or y >= mask.shape[0] or x >= mask.shape[1] or not mask[y, x]:
                continue
            for row in range(table.rowCount()):
                item = table.item(row, 0)
                if item and item.data(Qt.ItemDataRole.UserRole) == i:
                    range_to_select = QTableWidgetSelectionRange(
                        row, 0, row, table.columnCount() - 1
                    )
                    table.setRangeSelected(range_to_select, not item.isSelected())
                    return
            # Segment hidden by the class filter: fall through to the next.
        self.viewer.setFocus()

    def assign_selected_to_class(self):
        """Give all selected segments the class id of the first selected one."""
        selected_indices = self.get_selected_segment_indices()
        if not selected_indices:
            return
        target_class_id = self.segments[selected_indices[0]]["class_id"]
        for i in selected_indices:
            self.segments[i]["class_id"] = target_class_id
        self.update_all_lists()
        self.viewer.setFocus()

    def rasterize_polygon(self, vertices):
        """Rasterize polygon vertices to a bool mask the size of the image.

        Returns None when there are no vertices or no image is loaded.
        """
        pixmap = self.viewer._pixmap_item.pixmap()
        if not vertices or pixmap.isNull():
            return None
        h, w = pixmap.height(), pixmap.width()
        points_np = np.array([[p.x(), p.y()] for p in vertices], dtype=np.int32)
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask, [points_np], 1)
        return mask.astype(bool)

    # ------------------------------------------------------------------
    # Drawing
    # ------------------------------------------------------------------
    def display_all_segments(self):
        """Recreate the scene items for every segment, coloured by class."""
        scene = self.viewer.scene()
        for items in self.segment_items.values():
            for item in items:
                scene.removeItem(item)
        self.segment_items.clear()
        selected_indices = self.get_selected_segment_indices()
        hues = self._class_hues()

        for i, seg_dict in enumerate(self.segments):
            self.segment_items[i] = []
            hue = hues.get(seg_dict.get("class_id", 0), 0)
            base_color = QColor.fromHsv(hue, 220, 220)

            if seg_dict["type"] == "Polygon":
                poly_item = HoverablePolygonItem(QPolygonF(seg_dict["vertices"]))
                default_brush = QBrush(
                    QColor(base_color.red(), base_color.green(), base_color.blue(), 70)
                )
                hover_brush = QBrush(
                    QColor(base_color.red(), base_color.green(), base_color.blue(), 170)
                )
                poly_item.set_brushes(default_brush, hover_brush)
                poly_item.setPen(QPen(Qt.GlobalColor.transparent))
                scene.addItem(poly_item)
                self.segment_items[i].append(poly_item)
                vertex_color = QBrush(base_color)
                for v in seg_dict["vertices"]:
                    dot = QGraphicsEllipseItem(v.x() - 3, v.y() - 3, 6, 6)
                    dot.setBrush(vertex_color)
                    scene.addItem(dot)
                    self.segment_items[i].append(dot)
                if self.mode == "edit" and i in selected_indices:
                    # Draggable handles, only for selected polygons in edit mode.
                    for idx, v in enumerate(seg_dict["vertices"]):
                        vertex_item = EditableVertexItem(self, i, idx, -4, -4, 8, 8)
                        vertex_item.setPos(v)
                        scene.addItem(vertex_item)
                        self.segment_items[i].append(vertex_item)
            elif seg_dict.get("mask") is not None:
                pixmap = mask_to_pixmap(seg_dict["mask"], base_color.getRgb()[:3])
                pixmap_item = scene.addPixmap(pixmap)
                pixmap_item.setZValue(i + 1)
                self.segment_items[i].append(pixmap_item)
        self.highlight_selected_segments()

    def update_vertex_pos(self, seg_idx, vtx_idx, new_pos):
        """Move one polygon vertex and update the polygon's outline."""
        self.segments[seg_idx]["vertices"][vtx_idx] = new_pos
        self.update_polygon_visuals(seg_idx)

    def update_polygon_visuals(self, segment_index):
        """Refresh the polygon shape item of ``segment_index`` from its vertices."""
        for item in self.segment_items.get(segment_index, []):
            if isinstance(item, HoverablePolygonItem):
                item.setPolygon(QPolygonF(self.segments[segment_index]["vertices"]))
                break

    def highlight_selected_segments(self):
        """Overlay a white mask (z=100) on every selected segment."""
        scene = self.viewer.scene()
        for item in self.highlight_items:
            scene.removeItem(item)
        self.highlight_items.clear()
        for i in self.get_selected_segment_indices():
            mask = self._segment_mask(self.segments[i])
            if mask is not None:
                pixmap = mask_to_pixmap(mask, (255, 255, 255))
                highlight_item = scene.addPixmap(pixmap)
                highlight_item.setZValue(100)
                self.highlight_items.append(highlight_item)

    # ------------------------------------------------------------------
    # Side-panel lists
    # ------------------------------------------------------------------
    def update_all_lists(self):
        """Refresh the class filter, segment table, class list and scene."""
        self.update_class_filter_combo()
        self.update_segment_table()
        self.update_class_list()
        self.display_all_segments()

    def update_segment_table(self):
        """Rebuild the segment table for the active class filter.

        The current selection (by segment index) is kept. Sorting is turned
        off while rows are filled: with sorting enabled, QTableWidget.setItem
        may move the row mid-population and scatter a segment's cells.
        """
        table = self.right_panel.segment_table
        was_blocked = table.blockSignals(True)
        selected_indices = set(self.get_selected_segment_indices())
        table.setSortingEnabled(False)
        table.clearContents()
        table.setRowCount(0)

        filter_text = self.right_panel.class_filter_combo.currentText()
        show_all = filter_text == "All Classes"
        filter_class_id = -1
        if not show_all:
            try:
                # Filter entries are formatted "Class <id>".
                filter_class_id = int(filter_text.split(" ")[1])
            except (ValueError, IndexError):
                pass

        display_segments = [
            (i, seg)
            for i, seg in enumerate(self.segments)
            if show_all or seg.get("class_id") == filter_class_id
        ]
        table.setRowCount(len(display_segments))
        hues = self._class_hues()
        last_col = table.columnCount() - 1

        for row, (original_index, seg) in enumerate(display_segments):
            class_id = seg.get("class_id", 0)
            color = QColor.fromHsv(hues.get(class_id, 0), 150, 100)

            index_item = NumericTableWidgetItem(str(original_index + 1))
            class_item = NumericTableWidgetItem(str(class_id))
            type_item = QTableWidgetItem(seg["type"])

            index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
            # Column 0 carries the segment's index into self.segments.
            index_item.setData(Qt.ItemDataRole.UserRole, original_index)

            for col, cell in enumerate((index_item, class_item, type_item)):
                cell.setBackground(QBrush(color))
                table.setItem(row, col, cell)

            if original_index in selected_indices:
                # Add to the selection; selectRow() would clear-and-select in
                # extended selection mode and keep only the last row.
                table.setRangeSelected(
                    QTableWidgetSelectionRange(row, 0, row, last_col), True
                )

        table.setSortingEnabled(True)
        table.blockSignals(was_blocked)
        self.viewer.setFocus()

    def update_class_list(self):
        """Rebuild the class table, one coloured row per class id."""
        class_table = self.right_panel.class_table
        was_blocked = class_table.blockSignals(True)
        hues = self._class_hues()  # insertion order follows sorted class ids
        class_table.setRowCount(len(hues))
        for row, (cid, hue) in enumerate(hues.items()):
            item = QTableWidgetItem(str(cid))
            item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
            item.setBackground(QBrush(QColor.fromHsv(hue, 150, 100)))
            class_table.setItem(row, 0, item)
        class_table.blockSignals(was_blocked)

    def update_class_filter_combo(self):
        """Repopulate the class filter, keeping the current choice if it still exists."""
        combo = self.right_panel.class_filter_combo
        current_selection = combo.currentText()
        was_blocked = combo.blockSignals(True)
        combo.clear()
        combo.addItem("All Classes")
        combo.addItems([f"Class {cid}" for cid in self._unique_class_ids()])
        if combo.findText(current_selection) > -1:
            combo.setCurrentText(current_selection)
        else:
            combo.setCurrentIndex(0)
        combo.blockSignals(was_blocked)

    def reassign_class_ids(self):
        """Renumber classes 0..n-1 following the row order of the class table."""
        class_table = self.right_panel.class_table
        ordered_ids = [
            int(class_table.item(row, 0).text())
            for row in range(class_table.rowCount())
            if class_table.item(row, 0) is not None
        ]
        id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
        for seg in self.segments:
            old_id = seg.get("class_id")
            if old_id in id_map:
                seg["class_id"] = id_map[old_id]
        self.next_class_id = len(ordered_ids)
        self.update_all_lists()
        self.viewer.setFocus()

    def handle_class_id_change(self, item):
        """Apply a class id typed into column 1; revert non-integer input."""
        if item.column() != 1:
            return
        table = self.right_panel.segment_table
        was_blocked = table.blockSignals(True)
        try:
            original_index = table.item(item.row(), 0).data(Qt.ItemDataRole.UserRole)
            try:
                new_class_id = int(item.text())
            except (ValueError, TypeError):
                item.setText(str(self.segments[original_index]["class_id"]))
            else:
                self.segments[original_index]["class_id"] = new_class_id
                if new_class_id >= self.next_class_id:
                    self.next_class_id = new_class_id + 1
                self.update_all_lists()
        finally:
            table.blockSignals(was_blocked)
        self.viewer.setFocus()

    def get_selected_segment_indices(self):
        """Return the segment indices of the selected table rows, by row order."""
        table = self.right_panel.segment_table
        selected_rows = sorted({item.row() for item in table.selectedItems()})
        return [
            table.item(row, 0).data(Qt.ItemDataRole.UserRole)
            for row in selected_rows
            if table.item(row, 0)
        ]

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save_output_to_npz(self):
        """Write a one-hot (h, w, n_classes) uint8 mask next to the image.

        Class ids are compacted to channel indices 0..n-1 in sorted order.
        A filesystem error is reported in the status label instead of
        propagating out of the Qt key handler.
        """
        if not self.segments or not self.current_image_path:
            return
        status = self.right_panel.status_label
        status.setText("Saving...")
        QApplication.processEvents()

        output_path = os.path.splitext(self.current_image_path)[0] + ".npz"
        pixmap = self.viewer._pixmap_item.pixmap()
        h, w = pixmap.height(), pixmap.width()
        unique_class_ids = self._unique_class_ids()
        if not unique_class_ids:
            status.setText("Save failed: No classes.")
            QTimer.singleShot(3000, status.clear)
            return

        id_map = {old_id: new_id for new_id, old_id in enumerate(unique_class_ids)}
        final_mask_tensor = np.zeros((h, w, len(unique_class_ids)), dtype=np.uint8)
        for seg in self.segments:
            class_id = seg.get("class_id")
            if class_id not in id_map:
                continue
            mask = self._segment_mask(seg)
            if mask is not None:
                channel = id_map[class_id]
                final_mask_tensor[:, :, channel] = np.logical_or(
                    final_mask_tensor[:, :, channel], mask
                )

        try:
            np.savez_compressed(output_path, mask=final_mask_tensor)
        except OSError as err:
            status.setText(f"Save failed: {err}")
            QTimer.singleShot(3000, status.clear)
            return
        # Re-set the root path so the file tree picks up the new file.
        self.file_model.setRootPath(self.file_model.rootPath())

        status.setText("Saved!")
        QTimer.singleShot(3000, status.clear)

    def save_current_segment(self):
        """Commit the current SAM prediction as a new segment."""
        if self.mode != "sam_points" or self.preview_mask_item is None:
            return
        mask = self.sam_model.predict(self.positive_points, self.negative_points)
        if mask is not None:
            self.segments.append(
                {
                    "mask": mask,
                    "type": "SAM",
                    "vertices": None,
                    "class_id": self.next_class_id,
                }
            )
            self.next_class_id += 1
            self.clear_all_points()
            self.update_all_lists()

    def delete_selected_segments(self):
        """Delete every selected segment."""
        selected_indices = self.get_selected_segment_indices()
        if not selected_indices:
            return
        # Drop the selection first: it refers to pre-deletion indices, and the
        # table rebuild would otherwise select whichever segment shifted into
        # a deleted index.
        self.right_panel.segment_table.clearSelection()
        for i in sorted(selected_indices, reverse=True):
            del self.segments[i]
        self.update_all_lists()
        self.viewer.setFocus()

    def load_existing_mask(self):
        """Load ``<image>.npz`` (if present) as one "Loaded" segment per class channel."""
        if not self.current_image_path:
            return
        npz_path = os.path.splitext(self.current_image_path)[0] + ".npz"
        if os.path.exists(npz_path):
            with np.load(npz_path) as data:
                if "mask" in data:
                    mask_data = data["mask"]
                    if mask_data.ndim == 2:
                        mask_data = np.expand_dims(mask_data, axis=-1)
                    num_classes = mask_data.shape[2]
                    for i in range(num_classes):
                        class_mask = mask_data[:, :, i].astype(bool)
                        if np.any(class_mask):
                            self.segments.append(
                                {
                                    "mask": class_mask,
                                    "type": "Loaded",
                                    "vertices": None,
                                    "class_id": i,
                                }
                            )
                    self.next_class_id = num_classes
        self.update_all_lists()

    # ------------------------------------------------------------------
    # SAM prompts and polygon drawing
    # ------------------------------------------------------------------
    def add_point(self, pos, positive):
        """Record a SAM prompt point at ``pos`` and draw its marker."""
        point_list = self.positive_points if positive else self.negative_points
        point_list.append([int(pos.x()), int(pos.y())])
        color = Qt.GlobalColor.green if positive else Qt.GlobalColor.red
        point_item = QGraphicsEllipseItem(pos.x() - 4, pos.y() - 4, 8, 8)
        point_item.setBrush(QBrush(color))
        point_item.setPen(QPen(Qt.GlobalColor.white))
        self.viewer.scene().addItem(point_item)
        self.point_items.append(point_item)
        self._point_is_positive.append(bool(positive))

    def update_segmentation(self):
        """Re-run SAM on the current prompts and refresh the preview overlay."""
        if self.preview_mask_item is not None:
            self.viewer.scene().removeItem(self.preview_mask_item)
            # Forget the removed item so later calls don't reuse it.
            self.preview_mask_item = None
        if not self.positive_points:
            return
        mask = self.sam_model.predict(self.positive_points, self.negative_points)
        if mask is not None:
            pixmap = mask_to_pixmap(mask, (255, 255, 0))
            self.preview_mask_item = self.viewer.scene().addPixmap(pixmap)
            self.preview_mask_item.setZValue(50)

    def clear_all_points(self):
        """Discard SAM prompts, the in-progress polygon and all preview items."""
        scene = self.viewer.scene()
        if self.rubber_band_line:
            scene.removeItem(self.rubber_band_line)
            self.rubber_band_line = None
        self.positive_points.clear()
        self.negative_points.clear()
        for item in self.point_items:
            scene.removeItem(item)
        self.point_items.clear()
        self._point_is_positive.clear()
        self.polygon_points.clear()
        for item in self.polygon_preview_items:
            scene.removeItem(item)
        self.polygon_preview_items.clear()
        if self.preview_mask_item is not None:
            scene.removeItem(self.preview_mask_item)
            self.preview_mask_item = None

    def handle_polygon_click(self, pos):
        """Add a polygon vertex; a click within 5 px of the first vertex closes it."""
        if self.polygon_points and (
            (pos.x() - self.polygon_points[0].x()) ** 2
            + (pos.y() - self.polygon_points[0].y()) ** 2
        ) < 25:
            if len(self.polygon_points) > 2:
                self.finalize_polygon()
            return
        self.polygon_points.append(pos)
        dot = QGraphicsEllipseItem(pos.x() - 2, pos.y() - 2, 4, 4)
        dot.setBrush(QBrush(Qt.GlobalColor.blue))
        dot.setPen(QPen(Qt.GlobalColor.cyan))
        self.viewer.scene().addItem(dot)
        self.polygon_preview_items.append(dot)
        self.draw_polygon_preview()

    def draw_polygon_preview(self):
        """Redraw the in-progress polygon's fill and edges (vertex dots are kept)."""
        scene = self.viewer.scene()
        if self.rubber_band_line:
            scene.removeItem(self.rubber_band_line)
            self.rubber_band_line = None
        for item in self.polygon_preview_items:
            if not isinstance(item, QGraphicsEllipseItem):
                scene.removeItem(item)
        self.polygon_preview_items = [
            item
            for item in self.polygon_preview_items
            if isinstance(item, QGraphicsEllipseItem)
        ]

        if len(self.polygon_points) > 2:
            preview_poly = QGraphicsPolygonItem(QPolygonF(self.polygon_points))
            preview_poly.setBrush(QBrush(QColor(0, 255, 255, 100)))
            preview_poly.setPen(QPen(Qt.GlobalColor.transparent))
            scene.addItem(preview_poly)
            self.polygon_preview_items.append(preview_poly)

        for start, end in zip(self.polygon_points, self.polygon_points[1:]):
            line = QGraphicsLineItem(start.x(), start.y(), end.x(), end.y())
            line.setPen(QPen(Qt.GlobalColor.cyan, 2))
            scene.addItem(line)
            self.polygon_preview_items.append(line)
852
+
853
+
854
def main():
    """Launch the LazyLabel GUI and exit with the Qt event loop's status.

    The application and dark theme are set up first. The "vit_h" SAM model is
    then constructed, which per the original note performs a one-time
    check/download, before the main window is shown.
    """
    application = QApplication(sys.argv)
    qdarktheme.setup_theme()
    model = SamModel(model_type="vit_h")  # one-time check/download
    window = MainWindow(model)
    window.show()
    exit_code = application.exec()
    sys.exit(exit_code)
861
+
862
+
863
# Script entry point: only launch the GUI when executed directly, not on import.
if __name__ == "__main__":
    main()