lazylabel-gui 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lazylabel/main.py CHANGED
@@ -3,6 +3,7 @@ import os
3
3
  import numpy as np
4
4
  import qdarktheme
5
5
  import cv2
6
+ import json
6
7
  from PyQt6.QtWidgets import (
7
8
  QApplication,
8
9
  QMainWindow,
@@ -15,9 +16,21 @@ from PyQt6.QtWidgets import (
15
16
  QTableWidgetItem,
16
17
  QGraphicsPolygonItem,
17
18
  QTableWidgetSelectionRange,
19
+ QSpacerItem,
20
+ QHeaderView,
18
21
  )
19
- from PyQt6.QtGui import QPixmap, QColor, QPen, QBrush, QPolygonF, QIcon
20
- from PyQt6.QtCore import Qt, QPointF, QTimer
22
+ from PyQt6.QtGui import (
23
+ QPixmap,
24
+ QColor,
25
+ QPen,
26
+ QBrush,
27
+ QPolygonF,
28
+ QIcon,
29
+ QCursor,
30
+ QKeySequence,
31
+ QShortcut,
32
+ )
33
+ from PyQt6.QtCore import Qt, QPointF, QTimer, QModelIndex
21
34
 
22
35
  from .photo_viewer import PhotoViewer
23
36
  from .sam_model import SamModel
@@ -26,6 +39,7 @@ from .controls import ControlPanel, RightPanel
26
39
  from .custom_file_system_model import CustomFileSystemModel
27
40
  from .editable_vertex import EditableVertexItem
28
41
  from .hoverable_polygon_item import HoverablePolygonItem
42
+ from .hoverable_pixelmap_item import HoverablePixmapItem
29
43
  from .numeric_table_widget_item import NumericTableWidgetItem
30
44
 
31
45
 
@@ -34,7 +48,9 @@ class MainWindow(QMainWindow):
34
48
  super().__init__()
35
49
  self.setWindowTitle("LazyLabel by DNC")
36
50
 
37
- icon_path = os.path.join(os.path.dirname(__file__), "demo_pictures", "logo2.png")
51
+ icon_path = os.path.join(
52
+ os.path.dirname(__file__), "demo_pictures", "logo2.png"
53
+ )
38
54
  if os.path.exists(icon_path):
39
55
  self.setWindowIcon(QIcon(icon_path))
40
56
 
@@ -44,8 +60,18 @@ class MainWindow(QMainWindow):
44
60
  self.mode = "sam_points"
45
61
  self.previous_mode = "sam_points"
46
62
  self.current_image_path = None
47
- self.current_file_index = None
63
+ self.current_file_index = QModelIndex()
64
+
48
65
  self.next_class_id = 0
66
+ self.class_aliases = {}
67
+
68
+ self._original_point_radius = 0.3
69
+ self._original_line_thickness = 0.5
70
+ self.point_radius = self._original_point_radius
71
+ self.line_thickness = self._original_line_thickness
72
+
73
+ self.pan_multiplier = 1.0
74
+ self.polygon_join_threshold = 2
49
75
 
50
76
  self.point_items, self.positive_points, self.negative_points = [], [], []
51
77
  self.polygon_points, self.polygon_preview_items = [], []
@@ -65,6 +91,10 @@ class MainWindow(QMainWindow):
65
91
  self.file_model = CustomFileSystemModel()
66
92
  self.right_panel.file_tree.setModel(self.file_model)
67
93
  self.right_panel.file_tree.setColumnWidth(0, 200)
94
+ file_tree = self.right_panel.file_tree
95
+ header = file_tree.header()
96
+ header.setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
97
+ header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
68
98
 
69
99
  main_layout = QHBoxLayout()
70
100
  main_layout.addWidget(self.control_panel)
@@ -74,9 +104,12 @@ class MainWindow(QMainWindow):
74
104
  central_widget.setLayout(main_layout)
75
105
  self.setCentralWidget(central_widget)
76
106
 
77
- self.control_panel.device_label.setText(f"Device: {str(self.sam_model.device).upper()}")
107
+ self.control_panel.device_label.setText(
108
+ f"Device: {str(self.sam_model.device).upper()}"
109
+ )
78
110
  self.setup_connections()
79
111
  self.set_sam_mode()
112
+ self.set_annotation_size(10)
80
113
 
81
114
  def setup_connections(self):
82
115
  self._original_mouse_press = self.viewer.scene().mousePressEvent
@@ -89,25 +122,166 @@ class MainWindow(QMainWindow):
89
122
 
90
123
  self.right_panel.btn_open_folder.clicked.connect(self.open_folder_dialog)
91
124
  self.right_panel.file_tree.doubleClicked.connect(self.load_selected_image)
92
- self.right_panel.btn_merge_selection.clicked.connect(self.assign_selected_to_class)
93
- self.right_panel.btn_delete_selection.clicked.connect(self.delete_selected_segments)
94
- self.right_panel.segment_table.itemSelectionChanged.connect(self.highlight_selected_segments)
95
- self.right_panel.segment_table.itemChanged.connect(self.handle_class_id_change)
125
+ self.right_panel.btn_merge_selection.clicked.connect(
126
+ self.assign_selected_to_class
127
+ )
128
+ self.right_panel.btn_delete_selection.clicked.connect(
129
+ self.delete_selected_segments
130
+ )
131
+ self.right_panel.segment_table.itemSelectionChanged.connect(
132
+ self.highlight_selected_segments
133
+ )
134
+ self.right_panel.class_table.itemChanged.connect(self.handle_alias_change)
96
135
  self.right_panel.btn_reassign_classes.clicked.connect(self.reassign_class_ids)
97
- self.right_panel.class_filter_combo.currentIndexChanged.connect(self.update_segment_table)
136
+ self.right_panel.class_filter_combo.currentIndexChanged.connect(
137
+ self.update_segment_table
138
+ )
98
139
 
99
140
  self.control_panel.btn_sam_mode.clicked.connect(self.set_sam_mode)
100
141
  self.control_panel.btn_polygon_mode.clicked.connect(self.set_polygon_mode)
101
- self.control_panel.btn_selection_mode.clicked.connect(self.toggle_selection_mode)
142
+ self.control_panel.btn_selection_mode.clicked.connect(
143
+ self.toggle_selection_mode
144
+ )
102
145
  self.control_panel.btn_clear_points.clicked.connect(self.clear_all_points)
146
+ self.control_panel.btn_fit_view.clicked.connect(self.viewer.fitInView)
147
+
148
+ self.control_panel.size_slider.valueChanged.connect(self.set_annotation_size)
149
+ self.control_panel.pan_slider.valueChanged.connect(self.set_pan_multiplier)
150
+ self.control_panel.join_slider.valueChanged.connect(
151
+ self.set_polygon_join_threshold
152
+ )
153
+
154
+ self.control_panel.chk_save_npz.stateChanged.connect(
155
+ self.handle_save_checkbox_change
156
+ )
157
+ self.control_panel.chk_save_txt.stateChanged.connect(
158
+ self.handle_save_checkbox_change
159
+ )
160
+
161
+ self.control_panel.btn_toggle_visibility.clicked.connect(self.toggle_left_panel)
162
+ self.right_panel.btn_toggle_visibility.clicked.connect(self.toggle_right_panel)
163
+
164
+ QShortcut(QKeySequence(Qt.Key.Key_Right), self, self.load_next_image)
165
+ QShortcut(QKeySequence(Qt.Key.Key_Left), self, self.load_previous_image)
166
+ QShortcut(QKeySequence(Qt.Key.Key_1), self, self.set_sam_mode)
167
+ QShortcut(QKeySequence(Qt.Key.Key_2), self, self.set_polygon_mode)
168
+ QShortcut(QKeySequence(Qt.Key.Key_E), self, self.toggle_selection_mode)
169
+ QShortcut(QKeySequence(Qt.Key.Key_Q), self, self.toggle_pan_mode)
170
+ QShortcut(QKeySequence(Qt.Key.Key_R), self, self.toggle_edit_mode)
171
+ QShortcut(QKeySequence(Qt.Key.Key_C), self, self.clear_all_points)
172
+ QShortcut(QKeySequence(Qt.Key.Key_Escape), self, self.handle_escape_press)
173
+ QShortcut(QKeySequence(Qt.Key.Key_V), self, self.delete_selected_segments)
174
+ QShortcut(
175
+ QKeySequence(Qt.Key.Key_Backspace), self, self.delete_selected_segments
176
+ )
177
+ QShortcut(QKeySequence(Qt.Key.Key_M), self, self.handle_merge_press)
178
+ QShortcut(QKeySequence("Ctrl+Z"), self, self.undo_last_action)
179
+ QShortcut(
180
+ QKeySequence("Ctrl+A"), self, self.right_panel.segment_table.selectAll
181
+ )
182
+ QShortcut(QKeySequence(Qt.Key.Key_Space), self, self.handle_space_press)
183
+ QShortcut(QKeySequence(Qt.Key.Key_Return), self, self.handle_enter_press)
184
+ QShortcut(QKeySequence(Qt.Key.Key_Enter), self, self.handle_enter_press)
185
+ QShortcut(QKeySequence(Qt.Key.Key_Period), self, self.viewer.fitInView)
186
+
187
+ def toggle_left_panel(self):
188
+ is_visible = self.control_panel.main_controls_widget.isVisible()
189
+ self.control_panel.main_controls_widget.setVisible(not is_visible)
190
+ if is_visible:
191
+ self.control_panel.btn_toggle_visibility.setText("> Show")
192
+ self.control_panel.setFixedWidth(
193
+ self.control_panel.btn_toggle_visibility.sizeHint().width() + 20
194
+ )
195
+ else:
196
+ self.control_panel.btn_toggle_visibility.setText("< Hide")
197
+ self.control_panel.setFixedWidth(250)
198
+
199
+ def toggle_right_panel(self):
200
+ is_visible = self.right_panel.main_controls_widget.isVisible()
201
+ self.right_panel.main_controls_widget.setVisible(not is_visible)
202
+ layout = self.right_panel.v_layout
203
+
204
+ if is_visible: # Content is now hidden
205
+ layout.addStretch(1)
206
+ self.right_panel.btn_toggle_visibility.setText("< Show")
207
+ self.right_panel.setFixedWidth(
208
+ self.right_panel.btn_toggle_visibility.sizeHint().width() + 20
209
+ )
210
+ else: # Content is now visible
211
+ # Remove the stretch so the content can expand
212
+ for i in range(layout.count()):
213
+ item = layout.itemAt(i)
214
+ if isinstance(item, QSpacerItem):
215
+ layout.removeItem(item)
216
+ break
217
+ self.right_panel.btn_toggle_visibility.setText("Hide >")
218
+ self.right_panel.setFixedWidth(350)
219
+
220
+ def handle_save_checkbox_change(self):
221
+ is_npz_checked = self.control_panel.chk_save_npz.isChecked()
222
+ is_txt_checked = self.control_panel.chk_save_txt.isChecked()
223
+
224
+ if not is_npz_checked and not is_txt_checked:
225
+ sender = self.sender()
226
+ if sender == self.control_panel.chk_save_npz:
227
+ self.control_panel.chk_save_txt.setChecked(True)
228
+ else:
229
+ self.control_panel.chk_save_npz.setChecked(True)
230
+
231
+ def set_annotation_size(self, value):
232
+ multiplier = value / 10.0
233
+ self.point_radius = self._original_point_radius * multiplier
234
+ self.line_thickness = self._original_line_thickness * multiplier
235
+
236
+ self.control_panel.size_label.setText(f"Annotation Size: {multiplier:.1f}x")
237
+
238
+ if self.control_panel.size_slider.value() != value:
239
+ self.control_panel.size_slider.setValue(value)
103
240
 
104
- def _get_color_for_class(self, class_id, saturation, value):
241
+ self.display_all_segments()
242
+ self.clear_all_points()
243
+
244
+ def set_pan_multiplier(self, value):
245
+ self.pan_multiplier = value / 10.0
246
+ self.control_panel.pan_label.setText(f"Pan Speed: {self.pan_multiplier:.1f}x")
247
+
248
+ def set_polygon_join_threshold(self, value):
249
+ self.polygon_join_threshold = value
250
+ self.control_panel.join_label.setText(f"Polygon Join Distance: {value}px")
251
+
252
+ def handle_escape_press(self):
253
+ self.right_panel.segment_table.clearSelection()
254
+ self.right_panel.class_table.clearSelection()
255
+ self.clear_all_points()
256
+ self.viewer.setFocus()
257
+
258
+ def handle_space_press(self):
259
+ if self.mode == "polygon" and self.polygon_points:
260
+ self.finalize_polygon()
261
+ else:
262
+ self.save_current_segment()
263
+
264
+ def handle_enter_press(self):
265
+ if self.mode == "polygon" and self.polygon_points:
266
+ self.finalize_polygon()
267
+ else:
268
+ self.save_output_to_npz()
269
+
270
+ def handle_merge_press(self):
271
+ self.assign_selected_to_class()
272
+ self.right_panel.segment_table.clearSelection()
273
+
274
+ def show_notification(self, message, duration=3000):
275
+ self.control_panel.notification_label.setText(message)
276
+ QTimer.singleShot(
277
+ duration, lambda: self.control_panel.notification_label.clear()
278
+ )
279
+
280
+ def _get_color_for_class(self, class_id):
105
281
  if class_id is None:
106
282
  return QColor.fromHsv(0, 0, 128)
107
-
108
283
  hue = int((class_id * 222.4922359) % 360)
109
- color = QColor.fromHsv(hue, saturation, value)
110
-
284
+ color = QColor.fromHsv(hue, 220, 220)
111
285
  if not color.isValid():
112
286
  return QColor(Qt.GlobalColor.white)
113
287
  return color
@@ -122,9 +296,25 @@ class MainWindow(QMainWindow):
122
296
  self.previous_mode = self.mode
123
297
 
124
298
  self.mode = mode_name
125
- self.control_panel.mode_label.setText(f"Mode: {mode_name.replace('_', ' ').title()}")
299
+ self.control_panel.mode_label.setText(
300
+ f"Mode: {mode_name.replace('_', ' ').title()}"
301
+ )
126
302
  self.clear_all_points()
127
- self.viewer.setDragMode(self.viewer.DragMode.ScrollHandDrag if self.mode == "pan" else self.viewer.DragMode.NoDrag)
303
+
304
+ cursor_map = {
305
+ "sam_points": Qt.CursorShape.CrossCursor,
306
+ "polygon": Qt.CursorShape.CrossCursor,
307
+ "selection": Qt.CursorShape.ArrowCursor,
308
+ "edit": Qt.CursorShape.SizeAllCursor,
309
+ "pan": Qt.CursorShape.OpenHandCursor,
310
+ }
311
+ self.viewer.set_cursor(cursor_map.get(self.mode, Qt.CursorShape.ArrowCursor))
312
+
313
+ self.viewer.setDragMode(
314
+ self.viewer.DragMode.ScrollHandDrag
315
+ if self.mode == "pan"
316
+ else self.viewer.DragMode.NoDrag
317
+ )
128
318
 
129
319
  def set_sam_mode(self):
130
320
  self.set_mode("sam_points")
@@ -148,42 +338,92 @@ class MainWindow(QMainWindow):
148
338
 
149
339
  def toggle_edit_mode(self):
150
340
  selected_indices = self.get_selected_segment_indices()
151
- can_edit = any(self.segments[i].get("type") == "Polygon" for i in selected_indices)
341
+
152
342
  if self.mode == "edit":
153
343
  self.set_mode("selection", is_toggle=True)
154
- elif self.mode == "selection" and can_edit:
155
- self.set_mode("edit", is_toggle=True)
156
- self.display_all_segments()
344
+ return
345
+
346
+ if not selected_indices:
347
+ self.show_notification("Select a polygon to edit.")
348
+ return
349
+
350
+ can_edit = any(
351
+ self.segments[i].get("type") == "Polygon" for i in selected_indices
352
+ )
353
+
354
+ if not can_edit:
355
+ self.show_notification("Only polygon segments can be edited.")
356
+ return
357
+
358
+ self.set_mode("edit", is_toggle=True)
359
+ self.display_all_segments()
157
360
 
158
361
  def open_folder_dialog(self):
159
362
  folder_path = QFileDialog.getExistingDirectory(self, "Select Image Folder")
160
363
  if folder_path:
161
- self.right_panel.file_tree.setRootIndex(self.file_model.setRootPath(folder_path))
364
+ self.right_panel.file_tree.setRootIndex(
365
+ self.file_model.setRootPath(folder_path)
366
+ )
162
367
  self.viewer.setFocus()
163
368
 
164
369
  def load_selected_image(self, index):
165
- if not index.isValid():
370
+ if not index.isValid() or not self.file_model.isDir(index.parent()):
166
371
  return
167
372
 
168
373
  self.current_file_index = index
169
374
  path = self.file_model.filePath(index)
170
375
 
171
- if os.path.isfile(path) and path.lower().endswith((".png", ".jpg", ".jpeg", ".tiff", ".tif")):
376
+ if os.path.isfile(path) and path.lower().endswith(
377
+ (".png", ".jpg", ".jpeg", ".tiff", ".tif")
378
+ ):
172
379
  self.current_image_path = path
173
380
  pixmap = QPixmap(self.current_image_path)
174
381
  if not pixmap.isNull():
175
382
  self.reset_state()
176
383
  self.viewer.set_photo(pixmap)
177
384
  self.sam_model.set_image(self.current_image_path)
385
+ self.load_class_aliases()
178
386
  self.load_existing_mask()
387
+ self.right_panel.file_tree.setCurrentIndex(index)
179
388
  self.viewer.setFocus()
180
389
 
390
+ def load_next_image(self):
391
+ if not self.current_file_index.isValid():
392
+ return
393
+
394
+ if self.control_panel.chk_auto_save.isChecked():
395
+ self.save_output_to_npz()
396
+
397
+ row = self.current_file_index.row()
398
+ parent = self.current_file_index.parent()
399
+ if row + 1 < self.file_model.rowCount(parent):
400
+ next_index = self.file_model.index(row + 1, 0, parent)
401
+ self.load_selected_image(next_index)
402
+
403
+ def load_previous_image(self):
404
+ if not self.current_file_index.isValid():
405
+ return
406
+
407
+ if self.control_panel.chk_auto_save.isChecked():
408
+ self.save_output_to_npz()
409
+
410
+ row = self.current_file_index.row()
411
+ parent = self.current_file_index.parent()
412
+ if row > 0:
413
+ prev_index = self.file_model.index(row - 1, 0, parent)
414
+ self.load_selected_image(prev_index)
415
+
181
416
  def reset_state(self):
182
417
  self.clear_all_points()
183
418
  self.segments.clear()
419
+ self.class_aliases.clear()
184
420
  self.next_class_id = 0
185
421
  self.update_all_lists()
186
- items_to_remove = [item for item in self.viewer.scene().items() if item is not self.viewer._pixmap_item]
422
+ items_to_remove = [
423
+ item
424
+ for item in self.viewer.scene().items()
425
+ if item is not self.viewer._pixmap_item
426
+ ]
187
427
  for item in items_to_remove:
188
428
  self.viewer.scene().removeItem(item)
189
429
  self.segment_items.clear()
@@ -191,48 +431,69 @@ class MainWindow(QMainWindow):
191
431
 
192
432
  def keyPressEvent(self, event):
193
433
  key, mods = event.key(), event.modifiers()
194
- if event.isAutoRepeat():
434
+
435
+ if event.isAutoRepeat() and key not in {
436
+ Qt.Key.Key_W,
437
+ Qt.Key.Key_A,
438
+ Qt.Key.Key_S,
439
+ Qt.Key.Key_D,
440
+ }:
195
441
  return
442
+
443
+ shift_multiplier = 5.0 if mods & Qt.KeyboardModifier.ShiftModifier else 1.0
444
+
196
445
  if key == Qt.Key.Key_W:
197
- self.viewer.verticalScrollBar().setValue(self.viewer.verticalScrollBar().value() - int(self.viewer.height() * 0.1))
198
- elif key == Qt.Key.Key_S and not mods:
199
- self.viewer.verticalScrollBar().setValue(self.viewer.verticalScrollBar().value() + int(self.viewer.height() * 0.1))
200
- elif key == Qt.Key.Key_A and not (mods & Qt.KeyboardModifier.ControlModifier):
201
- self.viewer.horizontalScrollBar().setValue(self.viewer.horizontalScrollBar().value() - int(self.viewer.width() * 0.1))
446
+ amount = int(
447
+ self.viewer.height() * 0.1 * self.pan_multiplier * shift_multiplier
448
+ )
449
+ self.viewer.verticalScrollBar().setValue(
450
+ self.viewer.verticalScrollBar().value() - amount
451
+ )
452
+ elif key == Qt.Key.Key_S:
453
+ amount = int(
454
+ self.viewer.height() * 0.1 * self.pan_multiplier * shift_multiplier
455
+ )
456
+ self.viewer.verticalScrollBar().setValue(
457
+ self.viewer.verticalScrollBar().value() + amount
458
+ )
459
+ elif key == Qt.Key.Key_A:
460
+ amount = int(
461
+ self.viewer.width() * 0.1 * self.pan_multiplier * shift_multiplier
462
+ )
463
+ self.viewer.horizontalScrollBar().setValue(
464
+ self.viewer.horizontalScrollBar().value() - amount
465
+ )
202
466
  elif key == Qt.Key.Key_D:
203
- self.viewer.horizontalScrollBar().setValue(self.viewer.horizontalScrollBar().value() + int(self.viewer.width() * 0.1))
204
- elif key == Qt.Key.Key_1:
205
- self.set_sam_mode()
206
- elif key == Qt.Key.Key_2:
207
- self.set_polygon_mode()
208
- elif key == Qt.Key.Key_E:
209
- self.toggle_selection_mode()
210
- elif key == Qt.Key.Key_Q:
211
- self.toggle_pan_mode()
212
- elif key == Qt.Key.Key_R:
213
- self.toggle_edit_mode()
214
- elif key == Qt.Key.Key_C or key == Qt.Key.Key_Escape:
215
- self.clear_all_points()
216
- elif key == Qt.Key.Key_V or key == Qt.Key.Key_Backspace:
217
- self.delete_selected_segments()
218
- elif key == Qt.Key.Key_M:
219
- self.assign_selected_to_class()
220
- self.right_panel.segment_table.clearSelection()
221
- elif key == Qt.Key.Key_Z and mods == Qt.KeyboardModifier.ControlModifier:
222
- self.undo_last_action()
223
- elif key == Qt.Key.Key_A and mods == Qt.KeyboardModifier.ControlModifier:
224
- self.right_panel.segment_table.selectAll()
225
- elif key == Qt.Key.Key_Space:
226
- self.save_current_segment()
227
- elif key == Qt.Key.Key_Return or key == Qt.Key.Key_Enter:
228
- self.save_output_to_npz()
467
+ amount = int(
468
+ self.viewer.width() * 0.1 * self.pan_multiplier * shift_multiplier
469
+ )
470
+ self.viewer.horizontalScrollBar().setValue(
471
+ self.viewer.horizontalScrollBar().value() + amount
472
+ )
473
+ elif (
474
+ key == Qt.Key.Key_Equal or key == Qt.Key.Key_Plus
475
+ ) and mods == Qt.KeyboardModifier.ControlModifier:
476
+ current_val = self.control_panel.size_slider.value()
477
+ self.control_panel.size_slider.setValue(current_val + 1)
478
+ elif key == Qt.Key.Key_Minus and mods == Qt.KeyboardModifier.ControlModifier:
479
+ current_val = self.control_panel.size_slider.value()
480
+ self.control_panel.size_slider.setValue(current_val - 1)
481
+ else:
482
+ super().keyPressEvent(event)
229
483
 
230
484
  def scene_mouse_press(self, event):
231
485
  self._original_mouse_press(event)
232
486
  if event.isAccepted():
233
487
  return
488
+
489
+ if self.mode == "pan":
490
+ self.viewer.set_cursor(Qt.CursorShape.ClosedHandCursor)
491
+
234
492
  pos = event.scenePos()
235
- if self.viewer._pixmap_item.pixmap().isNull() or not self.viewer._pixmap_item.pixmap().rect().contains(pos.toPoint()):
493
+ if (
494
+ self.viewer._pixmap_item.pixmap().isNull()
495
+ or not self.viewer._pixmap_item.pixmap().rect().contains(pos.toPoint())
496
+ ):
236
497
  return
237
498
  if self.mode == "sam_points":
238
499
  if event.button() == Qt.MouseButton.LeftButton:
@@ -252,7 +513,9 @@ class MainWindow(QMainWindow):
252
513
  self.is_dragging_polygon = True
253
514
  selected_indices = self.get_selected_segment_indices()
254
515
  self.drag_initial_vertices = {
255
- i: list(self.segments[i]["vertices"]) for i in selected_indices if self.segments[i].get("type") == "Polygon"
516
+ i: list(self.segments[i]["vertices"])
517
+ for i in selected_indices
518
+ if self.segments[i].get("type") == "Polygon"
256
519
  }
257
520
 
258
521
  def scene_mouse_move(self, event):
@@ -260,12 +523,18 @@ class MainWindow(QMainWindow):
260
523
  if self.mode == "edit" and self.is_dragging_polygon:
261
524
  delta = pos - self.drag_start_pos
262
525
  for i, initial_verts in self.drag_initial_vertices.items():
263
- self.segments[i]["vertices"] = [QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts]
526
+ self.segments[i]["vertices"] = [
527
+ QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts
528
+ ]
264
529
  self.update_polygon_visuals(i)
265
530
  elif self.mode == "polygon" and self.polygon_points:
266
531
  if self.rubber_band_line is None:
267
532
  self.rubber_band_line = QGraphicsLineItem()
268
- self.rubber_band_line.setPen(QPen(Qt.GlobalColor.white, 2, Qt.PenStyle.DotLine))
533
+ line_color = QColor(Qt.GlobalColor.white)
534
+ line_color.setAlpha(150)
535
+ self.rubber_band_line.setPen(
536
+ QPen(line_color, self.line_thickness, Qt.PenStyle.DotLine)
537
+ )
269
538
  self.viewer.scene().addItem(self.rubber_band_line)
270
539
  self.rubber_band_line.setLine(
271
540
  self.polygon_points[-1].x(),
@@ -278,6 +547,8 @@ class MainWindow(QMainWindow):
278
547
  self._original_mouse_move(event)
279
548
 
280
549
  def scene_mouse_release(self, event):
550
+ if self.mode == "pan":
551
+ self.viewer.set_cursor(Qt.CursorShape.OpenHandCursor)
281
552
  if self.mode == "edit" and self.is_dragging_polygon:
282
553
  self.is_dragging_polygon = False
283
554
  self.drag_initial_vertices.clear()
@@ -286,12 +557,30 @@ class MainWindow(QMainWindow):
286
557
  def undo_last_action(self):
287
558
  if self.mode == "polygon" and self.polygon_points:
288
559
  self.polygon_points.pop()
289
- if self.polygon_preview_items:
290
- self.viewer.scene().removeItem(self.polygon_preview_items.pop())
560
+ for item in self.polygon_preview_items:
561
+ if item.scene():
562
+ self.viewer.scene().removeItem(item)
563
+ self.polygon_preview_items.clear()
564
+ for point in self.polygon_points:
565
+ point_diameter = self.point_radius * 2
566
+ point_color = QColor(Qt.GlobalColor.blue)
567
+ point_color.setAlpha(150)
568
+ dot = QGraphicsEllipseItem(
569
+ point.x() - self.point_radius,
570
+ point.y() - self.point_radius,
571
+ point_diameter,
572
+ point_diameter,
573
+ )
574
+ dot.setBrush(QBrush(point_color))
575
+ dot.setPen(QPen(Qt.GlobalColor.transparent))
576
+ self.viewer.scene().addItem(dot)
577
+ self.polygon_preview_items.append(dot)
291
578
  self.draw_polygon_preview()
292
579
  elif self.mode == "sam_points" and self.point_items:
293
580
  item_to_remove = self.point_items.pop()
294
- point_pos = item_to_remove.rect().topLeft() + QPointF(4, 4)
581
+ point_pos = item_to_remove.rect().topLeft() + QPointF(
582
+ self.point_radius, self.point_radius
583
+ )
295
584
  point_coords = [int(point_pos.x()), int(point_pos.y())]
296
585
  if point_coords in self.positive_points:
297
586
  self.positive_points.remove(point_coords)
@@ -300,6 +589,17 @@ class MainWindow(QMainWindow):
300
589
  self.viewer.scene().removeItem(item_to_remove)
301
590
  self.update_segmentation()
302
591
 
592
+ def _update_next_class_id(self):
593
+ all_ids = {
594
+ seg.get("class_id")
595
+ for seg in self.segments
596
+ if seg.get("class_id") is not None
597
+ }
598
+ if not all_ids:
599
+ self.next_class_id = 0
600
+ else:
601
+ self.next_class_id = max(all_ids) + 1
602
+
303
603
  def finalize_polygon(self):
304
604
  if len(self.polygon_points) < 3:
305
605
  return
@@ -314,7 +614,7 @@ class MainWindow(QMainWindow):
314
614
  "class_id": self.next_class_id,
315
615
  }
316
616
  )
317
- self.next_class_id += 1
617
+ self._update_next_class_id()
318
618
  self.polygon_points.clear()
319
619
  for item in self.polygon_preview_items:
320
620
  self.viewer.scene().removeItem(item)
@@ -325,14 +625,25 @@ class MainWindow(QMainWindow):
325
625
  x, y = int(pos.x()), int(pos.y())
326
626
  for i in range(len(self.segments) - 1, -1, -1):
327
627
  seg = self.segments[i]
328
- mask = self.rasterize_polygon(seg["vertices"]) if seg["type"] == "Polygon" else seg.get("mask")
329
- if mask is not None and y < mask.shape[0] and x < mask.shape[1] and mask[y, x]:
628
+ mask = (
629
+ self.rasterize_polygon(seg["vertices"])
630
+ if seg["type"] == "Polygon"
631
+ else seg.get("mask")
632
+ )
633
+ if (
634
+ mask is not None
635
+ and y < mask.shape[0]
636
+ and x < mask.shape[1]
637
+ and mask[y, x]
638
+ ):
330
639
  for j in range(self.right_panel.segment_table.rowCount()):
331
640
  item = self.right_panel.segment_table.item(j, 0)
332
641
  if item and item.data(Qt.ItemDataRole.UserRole) == i:
333
642
  table = self.right_panel.segment_table
334
643
  is_selected = table.item(j, 0).isSelected()
335
- range_to_select = QTableWidgetSelectionRange(j, 0, j, table.columnCount() - 1)
644
+ range_to_select = QTableWidgetSelectionRange(
645
+ j, 0, j, table.columnCount() - 1
646
+ )
336
647
  table.setRangeSelected(range_to_select, not is_selected)
337
648
  return
338
649
  self.viewer.setFocus()
@@ -342,18 +653,22 @@ class MainWindow(QMainWindow):
342
653
  if not selected_indices:
343
654
  return
344
655
 
345
- existing_class_ids = [self.segments[i]["class_id"] for i in selected_indices if self.segments[i].get("class_id") is not None]
656
+ existing_class_ids = [
657
+ self.segments[i]["class_id"]
658
+ for i in selected_indices
659
+ if self.segments[i].get("class_id") is not None
660
+ ]
346
661
 
347
662
  if existing_class_ids:
348
663
  target_class_id = min(existing_class_ids)
349
664
  else:
350
- target_class_id = self.segments[selected_indices[0]].get("class_id")
665
+ target_class_id = self.next_class_id
351
666
 
352
667
  for i in selected_indices:
353
668
  self.segments[i]["class_id"] = target_class_id
354
669
 
670
+ self._update_next_class_id()
355
671
  self.update_all_lists()
356
- self.right_panel.segment_table.clearSelection()
357
672
  self.viewer.setFocus()
358
673
 
359
674
  def rasterize_polygon(self, vertices):
@@ -378,31 +693,59 @@ class MainWindow(QMainWindow):
378
693
  for i, seg_dict in enumerate(self.segments):
379
694
  self.segment_items[i] = []
380
695
  class_id = seg_dict.get("class_id")
381
- base_color = self._get_color_for_class(class_id, saturation=220, value=220)
696
+ base_color = self._get_color_for_class(class_id)
382
697
 
383
698
  if seg_dict["type"] == "Polygon":
384
699
  poly_item = HoverablePolygonItem(QPolygonF(seg_dict["vertices"]))
385
- default_brush = QBrush(QColor(base_color.red(), base_color.green(), base_color.blue(), 70))
386
- hover_brush = QBrush(QColor(base_color.red(), base_color.green(), base_color.blue(), 170))
700
+ default_brush = QBrush(
701
+ QColor(base_color.red(), base_color.green(), base_color.blue(), 70)
702
+ )
703
+ hover_brush = QBrush(
704
+ QColor(base_color.red(), base_color.green(), base_color.blue(), 170)
705
+ )
387
706
  poly_item.set_brushes(default_brush, hover_brush)
388
707
  poly_item.setPen(QPen(Qt.GlobalColor.transparent))
389
708
  self.viewer.scene().addItem(poly_item)
390
709
  self.segment_items[i].append(poly_item)
710
+ base_color.setAlpha(150)
391
711
  vertex_color = QBrush(base_color)
712
+ point_diameter = self.point_radius * 2
392
713
  for v in seg_dict["vertices"]:
393
- dot = QGraphicsEllipseItem(v.x() - 3, v.y() - 3, 6, 6)
714
+ dot = QGraphicsEllipseItem(
715
+ v.x() - self.point_radius,
716
+ v.y() - self.point_radius,
717
+ point_diameter,
718
+ point_diameter,
719
+ )
394
720
  dot.setBrush(vertex_color)
721
+ dot.setPen(QPen(Qt.GlobalColor.transparent))
395
722
  self.viewer.scene().addItem(dot)
396
723
  self.segment_items[i].append(dot)
397
724
  if self.mode == "edit" and i in selected_indices:
725
+ handle_diameter = self.point_radius * 2
398
726
  for idx, v in enumerate(seg_dict["vertices"]):
399
- vertex_item = EditableVertexItem(self, i, idx, -4, -4, 8, 8)
727
+ vertex_item = EditableVertexItem(
728
+ self,
729
+ i,
730
+ idx,
731
+ -handle_diameter / 2,
732
+ -handle_diameter / 2,
733
+ handle_diameter,
734
+ handle_diameter,
735
+ )
400
736
  vertex_item.setPos(v)
401
737
  self.viewer.scene().addItem(vertex_item)
402
738
  self.segment_items[i].append(vertex_item)
403
739
  elif seg_dict.get("mask") is not None:
404
- pixmap = mask_to_pixmap(seg_dict["mask"], base_color.getRgb()[:3])
405
- pixmap_item = self.viewer.scene().addPixmap(pixmap)
740
+ default_pixmap = mask_to_pixmap(
741
+ seg_dict["mask"], base_color.getRgb()[:3], alpha=70
742
+ )
743
+ hover_pixmap = mask_to_pixmap(
744
+ seg_dict["mask"], base_color.getRgb()[:3], alpha=170
745
+ )
746
+ pixmap_item = HoverablePixmapItem()
747
+ pixmap_item.set_pixmaps(default_pixmap, hover_pixmap)
748
+ self.viewer.scene().addItem(pixmap_item)
406
749
  pixmap_item.setZValue(i + 1)
407
750
  self.segment_items[i].append(pixmap_item)
408
751
  self.highlight_selected_segments()
@@ -426,7 +769,11 @@ class MainWindow(QMainWindow):
426
769
  selected_indices = self.get_selected_segment_indices()
427
770
  for i in selected_indices:
428
771
  seg = self.segments[i]
429
- mask = self.rasterize_polygon(seg["vertices"]) if seg["type"] == "Polygon" else seg.get("mask")
772
+ mask = (
773
+ self.rasterize_polygon(seg["vertices"])
774
+ if seg["type"] == "Polygon"
775
+ else seg.get("mask")
776
+ )
430
777
  if mask is not None:
431
778
  pixmap = mask_to_pixmap(mask, (255, 255, 255))
432
779
  highlight_item = self.viewer.scene().addPixmap(pixmap)
@@ -434,9 +781,9 @@ class MainWindow(QMainWindow):
434
781
  self.highlight_items.append(highlight_item)
435
782
 
436
783
  def update_all_lists(self):
784
+ self.update_class_list()
437
785
  self.update_class_filter_combo()
438
786
  self.update_segment_table()
439
- self.update_class_list()
440
787
  self.display_all_segments()
441
788
 
442
789
  def update_segment_table(self):
@@ -450,7 +797,7 @@ class MainWindow(QMainWindow):
450
797
  filter_class_id = -1
451
798
  if not show_all:
452
799
  try:
453
- filter_class_id = int(filter_text.split(" ")[1])
800
+ filter_class_id = int(filter_text.split("(ID: ")[1][:-1])
454
801
  except (ValueError, IndexError):
455
802
  pass
456
803
 
@@ -463,22 +810,27 @@ class MainWindow(QMainWindow):
463
810
 
464
811
  for row, (original_index, seg) in enumerate(display_segments):
465
812
  class_id = seg.get("class_id")
466
- color = self._get_color_for_class(class_id, saturation=180, value=200)
467
-
813
+ color = self._get_color_for_class(class_id)
468
814
  class_id_str = str(class_id) if class_id is not None else "N/A"
815
+
816
+ alias_str = "N/A"
817
+ if class_id is not None:
818
+ alias_str = self.class_aliases.get(class_id, str(class_id))
819
+ alias_item = QTableWidgetItem(alias_str)
820
+
469
821
  index_item = NumericTableWidgetItem(str(original_index + 1))
470
822
  class_item = NumericTableWidgetItem(class_id_str)
471
- type_item = QTableWidgetItem(seg.get("type", "N/A"))
472
823
 
473
824
  index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
474
- type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
825
+ class_item.setFlags(class_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
826
+ alias_item.setFlags(alias_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
475
827
  index_item.setData(Qt.ItemDataRole.UserRole, original_index)
476
828
 
477
829
  table.setItem(row, 0, index_item)
478
830
  table.setItem(row, 1, class_item)
479
- table.setItem(row, 2, type_item)
831
+ table.setItem(row, 2, alias_item)
480
832
 
481
- for col in range(3):
833
+ for col in range(table.columnCount()):
482
834
  if table.item(row, col):
483
835
  table.item(row, col).setBackground(QBrush(color))
484
836
 
@@ -495,30 +847,59 @@ class MainWindow(QMainWindow):
495
847
  def update_class_list(self):
496
848
  class_table = self.right_panel.class_table
497
849
  class_table.blockSignals(True)
498
- class_table.clearContents()
499
850
 
500
- unique_class_ids = sorted(list({seg.get("class_id") for seg in self.segments if seg.get("class_id") is not None}))
501
- class_table.setRowCount(len(unique_class_ids))
851
+ preserved_aliases = self.class_aliases.copy()
852
+ unique_class_ids = sorted(
853
+ list(
854
+ {
855
+ seg.get("class_id")
856
+ for seg in self.segments
857
+ if seg.get("class_id") is not None
858
+ }
859
+ )
860
+ )
502
861
 
503
- for row, cid in enumerate(unique_class_ids):
504
- item = QTableWidgetItem(str(cid))
505
- item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
862
+ new_aliases = {}
863
+ for cid in unique_class_ids:
864
+ new_aliases[cid] = preserved_aliases.get(cid, str(cid))
506
865
 
507
- color = self._get_color_for_class(cid, saturation=180, value=200)
866
+ self.class_aliases = new_aliases
508
867
 
509
- item.setBackground(QBrush(color))
510
- class_table.setItem(row, 0, item)
868
+ class_table.clearContents()
869
+ class_table.setRowCount(len(unique_class_ids))
870
+ for row, cid in enumerate(unique_class_ids):
871
+ alias_item = QTableWidgetItem(self.class_aliases.get(cid))
872
+ id_item = QTableWidgetItem(str(cid))
873
+ id_item.setFlags(id_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
874
+ color = self._get_color_for_class(cid)
875
+ alias_item.setBackground(QBrush(color))
876
+ id_item.setBackground(QBrush(color))
877
+ class_table.setItem(row, 0, alias_item)
878
+ class_table.setItem(row, 1, id_item)
511
879
 
512
880
  class_table.blockSignals(False)
513
881
 
514
882
  def update_class_filter_combo(self):
515
883
  combo = self.right_panel.class_filter_combo
516
- unique_class_ids = sorted(list({seg.get("class_id") for seg in self.segments if seg.get("class_id") is not None}))
884
+ unique_class_ids = sorted(
885
+ list(
886
+ {
887
+ seg.get("class_id")
888
+ for seg in self.segments
889
+ if seg.get("class_id") is not None
890
+ }
891
+ )
892
+ )
517
893
  current_selection = combo.currentText()
518
894
  combo.blockSignals(True)
519
895
  combo.clear()
520
896
  combo.addItem("All Classes")
521
- combo.addItems([f"Class {cid}" for cid in unique_class_ids])
897
+ combo.addItems(
898
+ [
899
+ f"{self.class_aliases.get(cid, cid)} (ID: {cid})"
900
+ for cid in unique_class_ids
901
+ ]
902
+ )
522
903
  if combo.findText(current_selection) > -1:
523
904
  combo.setCurrentText(current_selection)
524
905
  else:
@@ -527,123 +908,161 @@ class MainWindow(QMainWindow):
527
908
 
528
909
  def reassign_class_ids(self):
529
910
  class_table = self.right_panel.class_table
530
- ordered_ids = [
531
- int(class_table.item(row, 0).text()) for row in range(class_table.rowCount()) if class_table.item(row, 0) is not None
532
- ]
911
+ ordered_ids = []
912
+ for row in range(class_table.rowCount()):
913
+ id_item = class_table.item(row, 1)
914
+ if id_item and id_item.text():
915
+ try:
916
+ ordered_ids.append(int(id_item.text()))
917
+ except ValueError:
918
+ continue
533
919
  id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
534
920
  for seg in self.segments:
535
921
  old_id = seg.get("class_id")
536
922
  if old_id in id_map:
537
923
  seg["class_id"] = id_map[old_id]
538
- self.next_class_id = len(ordered_ids)
924
+ new_aliases = {
925
+ id_map[old_id]: self.class_aliases.get(old_id, str(old_id))
926
+ for old_id in ordered_ids
927
+ if old_id in self.class_aliases
928
+ }
929
+ self.class_aliases = new_aliases
930
+ self._update_next_class_id()
539
931
  self.update_all_lists()
540
932
  self.viewer.setFocus()
541
933
 
542
- def handle_class_id_change(self, item):
543
- if item.column() != 1:
544
- return
545
- table = self.right_panel.segment_table
546
- index_item = table.item(item.row(), 0)
547
- if not index_item:
934
+ def handle_alias_change(self, item):
935
+ if item.column() != 0:
548
936
  return
937
+ class_table = self.right_panel.class_table
938
+ class_table.blockSignals(True)
939
+ id_item = class_table.item(item.row(), 1)
940
+ if id_item:
941
+ try:
942
+ class_id = int(id_item.text())
943
+ self.class_aliases[class_id] = item.text()
944
+ except (ValueError, AttributeError):
945
+ pass
946
+ class_table.blockSignals(False)
549
947
 
550
- table.blockSignals(True)
551
- try:
552
- new_class_id_text = item.text()
553
- if not new_class_id_text.strip():
554
- raise ValueError("Class ID cannot be empty.")
555
- new_class_id = int(new_class_id_text)
556
- original_index = index_item.data(Qt.ItemDataRole.UserRole)
557
-
558
- if original_index is None or original_index >= len(self.segments):
559
- raise IndexError("Invalid segment index found in table.")
560
-
561
- self.segments[original_index]["class_id"] = new_class_id
562
- if new_class_id >= self.next_class_id:
563
- self.next_class_id = new_class_id + 1
564
- self.update_all_lists()
565
- except (ValueError, TypeError, AttributeError, IndexError) as e:
566
- original_index = index_item.data(Qt.ItemDataRole.UserRole)
567
- if original_index is not None and original_index < len(self.segments):
568
- original_class_id = self.segments[original_index].get("class_id")
569
- item.setText(str(original_class_id) if original_class_id is not None else "N/A")
570
- finally:
571
- table.blockSignals(False)
572
- self.viewer.setFocus()
948
+ self.update_class_filter_combo()
949
+ self.update_segment_table()
573
950
 
574
951
  def get_selected_segment_indices(self):
575
952
  table = self.right_panel.segment_table
576
953
  selected_items = table.selectedItems()
577
954
  selected_rows = sorted(list({item.row() for item in selected_items}))
578
- return [table.item(row, 0).data(Qt.ItemDataRole.UserRole) for row in selected_rows if table.item(row, 0)]
955
+ return [
956
+ table.item(row, 0).data(Qt.ItemDataRole.UserRole)
957
+ for row in selected_rows
958
+ if table.item(row, 0)
959
+ ]
579
960
 
580
961
  def save_output_to_npz(self):
581
- if not self.segments or not self.current_image_path:
582
- return
583
- self.right_panel.status_label.setText("Saving...")
584
- QApplication.processEvents()
962
+ save_npz = self.control_panel.chk_save_npz.isChecked()
963
+ save_txt = self.control_panel.chk_save_txt.isChecked()
964
+ save_aliases = self.control_panel.chk_save_class_aliases.isChecked()
585
965
 
586
- output_path = os.path.splitext(self.current_image_path)[0] + ".npz"
587
- h, w = (
588
- self.viewer._pixmap_item.pixmap().height(),
589
- self.viewer._pixmap_item.pixmap().width(),
590
- )
591
- unique_class_ids = sorted(list({seg["class_id"] for seg in self.segments if seg.get("class_id") is not None}))
592
- if not unique_class_ids:
593
- self.right_panel.status_label.setText("Save failed: No classes.")
594
- QTimer.singleShot(3000, lambda: self.right_panel.status_label.clear())
966
+ if not self.current_image_path or not any([save_npz, save_txt, save_aliases]):
595
967
  return
596
968
 
597
- id_map = {old_id: new_id for new_id, old_id in enumerate(unique_class_ids)}
598
- num_final_classes = len(unique_class_ids)
599
- final_mask_tensor = np.zeros((h, w, num_final_classes), dtype=np.uint8)
969
+ self.right_panel.status_label.setText("Saving...")
970
+ QApplication.processEvents()
600
971
 
601
- for seg in self.segments:
602
- class_id = seg.get("class_id")
603
- if class_id not in id_map:
604
- continue
605
- new_channel_idx = id_map[class_id]
606
- mask = self.rasterize_polygon(seg["vertices"]) if seg["type"] == "Polygon" else seg.get("mask")
607
- if mask is not None:
608
- final_mask_tensor[:, :, new_channel_idx] = np.logical_or(final_mask_tensor[:, :, new_channel_idx], mask)
972
+ saved_something = False
609
973
 
610
- np.savez_compressed(output_path, mask=final_mask_tensor.astype(np.uint8))
611
- self.file_model.setRootPath(self.file_model.rootPath())
974
+ if save_npz or save_txt:
975
+ if not self.segments:
976
+ self.show_notification("No segments to save.")
977
+ else:
978
+ h, w = (
979
+ self.viewer._pixmap_item.pixmap().height(),
980
+ self.viewer._pixmap_item.pixmap().width(),
981
+ )
982
+ class_table = self.right_panel.class_table
983
+ ordered_ids = [
984
+ int(class_table.item(row, 1).text())
985
+ for row in range(class_table.rowCount())
986
+ if class_table.item(row, 1) is not None
987
+ ]
988
+
989
+ if not ordered_ids:
990
+ self.show_notification("No classes defined for mask saving.")
991
+ else:
992
+ id_map = {
993
+ old_id: new_id for new_id, old_id in enumerate(ordered_ids)
994
+ }
995
+ num_final_classes = len(ordered_ids)
996
+ final_mask_tensor = np.zeros(
997
+ (h, w, num_final_classes), dtype=np.uint8
998
+ )
999
+
1000
+ for seg in self.segments:
1001
+ class_id = seg.get("class_id")
1002
+ if class_id not in id_map:
1003
+ continue
1004
+ new_channel_idx = id_map[class_id]
1005
+ mask = (
1006
+ self.rasterize_polygon(seg["vertices"])
1007
+ if seg["type"] == "Polygon"
1008
+ else seg.get("mask")
1009
+ )
1010
+ if mask is not None:
1011
+ final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
1012
+ final_mask_tensor[:, :, new_channel_idx], mask
1013
+ )
1014
+ if save_npz:
1015
+ npz_path = os.path.splitext(self.current_image_path)[0] + ".npz"
1016
+ np.savez_compressed(
1017
+ npz_path, mask=final_mask_tensor.astype(np.uint8)
1018
+ )
1019
+ self.file_model.set_highlighted_path(npz_path)
1020
+ QTimer.singleShot(
1021
+ 1500, lambda: self.file_model.set_highlighted_path(None)
1022
+ )
1023
+ saved_something = True
1024
+ if save_txt:
1025
+ self.generate_yolo_annotations(final_mask_tensor)
1026
+ saved_something = True
1027
+
1028
+ if save_aliases:
1029
+ aliases_path = os.path.splitext(self.current_image_path)[0] + ".json"
1030
+ aliases_to_save = {str(k): v for k, v in self.class_aliases.items()}
1031
+ with open(aliases_path, "w") as f:
1032
+ json.dump(aliases_to_save, f, indent=4)
1033
+ saved_something = True
1034
+
1035
+ if saved_something:
1036
+ self.right_panel.status_label.setText("Saved!")
1037
+ else:
1038
+ self.right_panel.status_label.clear()
612
1039
 
613
- self.right_panel.status_label.setText("Saved!")
614
- self.generate_yolo_annotations(npz_file_path=output_path)
615
1040
  QTimer.singleShot(3000, lambda: self.right_panel.status_label.clear())
616
1041
 
617
- def generate_yolo_annotations(self, npz_file_path):
1042
+ def generate_yolo_annotations(self, mask_tensor):
618
1043
  output_path = os.path.splitext(self.current_image_path)[0] + ".txt"
619
- npz_data = np.load(npz_file_path) # Load the saved npz file
620
-
621
- img = npz_data["mask"][:, :, :]
622
- num_channels = img.shape[2] # C
623
- h, w = img.shape[:2] # H, W
1044
+ h, w, num_channels = mask_tensor.shape
624
1045
 
625
1046
  directory_path = os.path.dirname(output_path)
626
1047
  os.makedirs(directory_path, exist_ok=True)
627
1048
 
628
1049
  yolo_annotations = []
629
-
630
1050
  for channel in range(num_channels):
631
- single_channel_image = img[:, :, channel]
632
- contours, _ = cv2.findContours(single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
633
-
634
- class_id = channel # Use the channel index as the class ID
1051
+ single_channel_image = mask_tensor[:, :, channel]
1052
+ if not np.any(single_channel_image):
1053
+ continue
635
1054
 
1055
+ contours, _ = cv2.findContours(
1056
+ single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
1057
+ )
1058
+ class_id = channel
636
1059
  for contour in contours:
637
1060
  x, y, width, height = cv2.boundingRect(contour)
638
- center_x = x + width / 2
639
- center_y = y + height / 2
640
-
641
- normalized_center_x = center_x / w
642
- normalized_center_y = center_y / h
1061
+ center_x = (x + width / 2) / w
1062
+ center_y = (y + height / 2) / h
643
1063
  normalized_width = width / w
644
1064
  normalized_height = height / h
645
-
646
- yolo_entry = f"{class_id} {normalized_center_x} {normalized_center_y} {normalized_width} {normalized_height}"
1065
+ yolo_entry = f"{class_id} {center_x} {center_y} {normalized_width} {normalized_height}"
647
1066
  yolo_annotations.append(yolo_entry)
648
1067
 
649
1068
  with open(output_path, "w") as file:
@@ -651,7 +1070,11 @@ class MainWindow(QMainWindow):
651
1070
  file.write(annotation + "\n")
652
1071
 
653
1072
  def save_current_segment(self):
654
- if self.mode != "sam_points" or not hasattr(self, "preview_mask_item") or not self.preview_mask_item:
1073
+ if (
1074
+ self.mode != "sam_points"
1075
+ or not hasattr(self, "preview_mask_item")
1076
+ or not self.preview_mask_item
1077
+ ):
655
1078
  return
656
1079
  mask = self.sam_model.predict(self.positive_points, self.negative_points)
657
1080
  if mask is not None:
@@ -663,7 +1086,7 @@ class MainWindow(QMainWindow):
663
1086
  "class_id": self.next_class_id,
664
1087
  }
665
1088
  )
666
- self.next_class_id += 1
1089
+ self._update_next_class_id()
667
1090
  self.clear_all_points()
668
1091
  self.update_all_lists()
669
1092
 
@@ -673,9 +1096,24 @@ class MainWindow(QMainWindow):
673
1096
  return
674
1097
  for i in sorted(selected_indices, reverse=True):
675
1098
  del self.segments[i]
1099
+ self._update_next_class_id()
676
1100
  self.update_all_lists()
677
1101
  self.viewer.setFocus()
678
1102
 
1103
+ def load_class_aliases(self):
1104
+ if not self.current_image_path:
1105
+ return
1106
+ json_path = os.path.splitext(self.current_image_path)[0] + ".json"
1107
+ if os.path.exists(json_path):
1108
+ try:
1109
+ with open(json_path, "r") as f:
1110
+ loaded_aliases = json.load(f)
1111
+ # JSON loads keys as strings, convert them to int
1112
+ self.class_aliases = {int(k): v for k, v in loaded_aliases.items()}
1113
+ except (json.JSONDecodeError, ValueError) as e:
1114
+ print(f"Error loading class aliases from {json_path}: {e}")
1115
+ self.class_aliases.clear()
1116
+
679
1117
  def load_existing_mask(self):
680
1118
  if not self.current_image_path:
681
1119
  return
@@ -698,16 +1136,25 @@ class MainWindow(QMainWindow):
698
1136
  "class_id": i,
699
1137
  }
700
1138
  )
701
- self.next_class_id = num_classes
1139
+ self._update_next_class_id()
702
1140
  self.update_all_lists()
703
1141
 
704
1142
  def add_point(self, pos, positive):
705
1143
  point_list = self.positive_points if positive else self.negative_points
706
1144
  point_list.append([int(pos.x()), int(pos.y())])
707
- color = Qt.GlobalColor.green if positive else Qt.GlobalColor.red
708
- point_item = QGraphicsEllipseItem(pos.x() - 4, pos.y() - 4, 8, 8)
709
- point_item.setBrush(QBrush(color))
710
- point_item.setPen(QPen(Qt.GlobalColor.white))
1145
+ point_color = (
1146
+ QColor(Qt.GlobalColor.green) if positive else QColor(Qt.GlobalColor.red)
1147
+ )
1148
+ point_color.setAlpha(150)
1149
+ point_diameter = self.point_radius * 2
1150
+ point_item = QGraphicsEllipseItem(
1151
+ pos.x() - self.point_radius,
1152
+ pos.y() - self.point_radius,
1153
+ point_diameter,
1154
+ point_diameter,
1155
+ )
1156
+ point_item.setBrush(QBrush(point_color))
1157
+ point_item.setPen(QPen(Qt.GlobalColor.transparent))
711
1158
  self.viewer.scene().addItem(point_item)
712
1159
  self.point_items.append(point_item)
713
1160
 
@@ -740,27 +1187,42 @@ class MainWindow(QMainWindow):
740
1187
  self.preview_mask_item = None
741
1188
 
742
1189
  def handle_polygon_click(self, pos):
743
- if self.polygon_points and (((pos.x() - self.polygon_points[0].x()) ** 2 + (pos.y() - self.polygon_points[0].y()) ** 2) < 25):
1190
+ if self.polygon_points and (
1191
+ (
1192
+ (pos.x() - self.polygon_points[0].x()) ** 2
1193
+ + (pos.y() - self.polygon_points[0].y()) ** 2
1194
+ )
1195
+ < self.polygon_join_threshold**2
1196
+ ):
744
1197
  if len(self.polygon_points) > 2:
745
1198
  self.finalize_polygon()
746
1199
  return
747
1200
  self.polygon_points.append(pos)
748
- dot = QGraphicsEllipseItem(pos.x() - 2, pos.y() - 2, 4, 4)
749
- dot.setBrush(QBrush(Qt.GlobalColor.blue))
750
- dot.setPen(QPen(Qt.GlobalColor.cyan))
1201
+ point_diameter = self.point_radius * 2
1202
+ point_color = QColor(Qt.GlobalColor.blue)
1203
+ point_color.setAlpha(150)
1204
+ dot = QGraphicsEllipseItem(
1205
+ pos.x() - self.point_radius,
1206
+ pos.y() - self.point_radius,
1207
+ point_diameter,
1208
+ point_diameter,
1209
+ )
1210
+ dot.setBrush(QBrush(point_color))
1211
+ dot.setPen(QPen(Qt.GlobalColor.transparent))
751
1212
  self.viewer.scene().addItem(dot)
752
1213
  self.polygon_preview_items.append(dot)
753
1214
  self.draw_polygon_preview()
754
1215
 
755
1216
  def draw_polygon_preview(self):
756
- if self.rubber_band_line:
757
- self.viewer.scene().removeItem(self.rubber_band_line)
758
- self.rubber_band_line = None
759
1217
  for item in self.polygon_preview_items:
760
1218
  if not isinstance(item, QGraphicsEllipseItem):
761
- self.viewer.scene().removeItem(item)
762
- self.polygon_preview_items = [item for item in self.polygon_preview_items if isinstance(item, QGraphicsEllipseItem)]
763
-
1219
+ if item.scene():
1220
+ self.viewer.scene().removeItem(item)
1221
+ self.polygon_preview_items = [
1222
+ item
1223
+ for item in self.polygon_preview_items
1224
+ if isinstance(item, QGraphicsEllipseItem)
1225
+ ]
764
1226
  if len(self.polygon_points) > 2:
765
1227
  preview_poly = QGraphicsPolygonItem(QPolygonF(self.polygon_points))
766
1228
  preview_poly.setBrush(QBrush(QColor(0, 255, 255, 100)))
@@ -769,6 +1231,8 @@ class MainWindow(QMainWindow):
769
1231
  self.polygon_preview_items.append(preview_poly)
770
1232
 
771
1233
  if len(self.polygon_points) > 1:
1234
+ line_color = QColor(Qt.GlobalColor.cyan)
1235
+ line_color.setAlpha(150)
772
1236
  for i in range(len(self.polygon_points) - 1):
773
1237
  line = QGraphicsLineItem(
774
1238
  self.polygon_points[i].x(),
@@ -776,7 +1240,7 @@ class MainWindow(QMainWindow):
776
1240
  self.polygon_points[i + 1].x(),
777
1241
  self.polygon_points[i + 1].y(),
778
1242
  )
779
- line.setPen(QPen(Qt.GlobalColor.cyan, 2))
1243
+ line.setPen(QPen(line_color, self.line_thickness))
780
1244
  self.viewer.scene().addItem(line)
781
1245
  self.polygon_preview_items.append(line)
782
1246