lazylabel-gui 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lazylabel/main.py CHANGED
@@ -3,6 +3,7 @@ import os
3
3
  import numpy as np
4
4
  import qdarktheme
5
5
  import cv2
6
+ import json
6
7
  from PyQt6.QtWidgets import (
7
8
  QApplication,
8
9
  QMainWindow,
@@ -15,6 +16,8 @@ from PyQt6.QtWidgets import (
15
16
  QTableWidgetItem,
16
17
  QGraphicsPolygonItem,
17
18
  QTableWidgetSelectionRange,
19
+ QSpacerItem,
20
+ QHeaderView,
18
21
  )
19
22
  from PyQt6.QtGui import (
20
23
  QPixmap,
@@ -60,14 +63,15 @@ class MainWindow(QMainWindow):
60
63
  self.current_file_index = QModelIndex()
61
64
 
62
65
  self.next_class_id = 0
66
+ self.class_aliases = {}
63
67
 
64
- self.class_aliases = {} # {class_id: "alias_string"}
68
+ self._original_point_radius = 0.3
69
+ self._original_line_thickness = 0.5
70
+ self.point_radius = self._original_point_radius
71
+ self.line_thickness = self._original_line_thickness
65
72
 
66
- self.point_radius = 0.3
67
- self.line_thickness = 0.5
68
-
69
- self._original_point_radius = self.point_radius
70
- self._original_line_thickness = self.line_thickness
73
+ self.pan_multiplier = 1.0
74
+ self.polygon_join_threshold = 2
71
75
 
72
76
  self.point_items, self.positive_points, self.negative_points = [], [], []
73
77
  self.polygon_points, self.polygon_preview_items = [], []
@@ -87,6 +91,10 @@ class MainWindow(QMainWindow):
87
91
  self.file_model = CustomFileSystemModel()
88
92
  self.right_panel.file_tree.setModel(self.file_model)
89
93
  self.right_panel.file_tree.setColumnWidth(0, 200)
94
+ file_tree = self.right_panel.file_tree
95
+ header = file_tree.header()
96
+ header.setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
97
+ header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
90
98
 
91
99
  main_layout = QHBoxLayout()
92
100
  main_layout.addWidget(self.control_panel)
@@ -101,6 +109,7 @@ class MainWindow(QMainWindow):
101
109
  )
102
110
  self.setup_connections()
103
111
  self.set_sam_mode()
112
+ self.set_annotation_size(10)
104
113
 
105
114
  def setup_connections(self):
106
115
  self._original_mouse_press = self.viewer.scene().mousePressEvent
@@ -122,7 +131,6 @@ class MainWindow(QMainWindow):
122
131
  self.right_panel.segment_table.itemSelectionChanged.connect(
123
132
  self.highlight_selected_segments
124
133
  )
125
- self.right_panel.segment_table.itemChanged.connect(self.handle_class_id_change)
126
134
  self.right_panel.class_table.itemChanged.connect(self.handle_alias_change)
127
135
  self.right_panel.btn_reassign_classes.clicked.connect(self.reassign_class_ids)
128
136
  self.right_panel.class_filter_combo.currentIndexChanged.connect(
@@ -137,11 +145,131 @@ class MainWindow(QMainWindow):
137
145
  self.control_panel.btn_clear_points.clicked.connect(self.clear_all_points)
138
146
  self.control_panel.btn_fit_view.clicked.connect(self.viewer.fitInView)
139
147
 
140
- # **FIX:** Use QShortcut for reliable global hotkeys
141
- next_shortcut = QShortcut(QKeySequence(Qt.Key.Key_Right), self)
142
- next_shortcut.activated.connect(self.load_next_image)
143
- prev_shortcut = QShortcut(QKeySequence(Qt.Key.Key_Left), self)
144
- prev_shortcut.activated.connect(self.load_previous_image)
148
+ self.control_panel.size_slider.valueChanged.connect(self.set_annotation_size)
149
+ self.control_panel.pan_slider.valueChanged.connect(self.set_pan_multiplier)
150
+ self.control_panel.join_slider.valueChanged.connect(
151
+ self.set_polygon_join_threshold
152
+ )
153
+
154
+ self.control_panel.chk_save_npz.stateChanged.connect(
155
+ self.handle_save_checkbox_change
156
+ )
157
+ self.control_panel.chk_save_txt.stateChanged.connect(
158
+ self.handle_save_checkbox_change
159
+ )
160
+
161
+ self.control_panel.btn_toggle_visibility.clicked.connect(self.toggle_left_panel)
162
+ self.right_panel.btn_toggle_visibility.clicked.connect(self.toggle_right_panel)
163
+
164
+ QShortcut(QKeySequence(Qt.Key.Key_Right), self, self.load_next_image)
165
+ QShortcut(QKeySequence(Qt.Key.Key_Left), self, self.load_previous_image)
166
+ QShortcut(QKeySequence(Qt.Key.Key_1), self, self.set_sam_mode)
167
+ QShortcut(QKeySequence(Qt.Key.Key_2), self, self.set_polygon_mode)
168
+ QShortcut(QKeySequence(Qt.Key.Key_E), self, self.toggle_selection_mode)
169
+ QShortcut(QKeySequence(Qt.Key.Key_Q), self, self.toggle_pan_mode)
170
+ QShortcut(QKeySequence(Qt.Key.Key_R), self, self.toggle_edit_mode)
171
+ QShortcut(QKeySequence(Qt.Key.Key_C), self, self.clear_all_points)
172
+ QShortcut(QKeySequence(Qt.Key.Key_Escape), self, self.handle_escape_press)
173
+ QShortcut(QKeySequence(Qt.Key.Key_V), self, self.delete_selected_segments)
174
+ QShortcut(
175
+ QKeySequence(Qt.Key.Key_Backspace), self, self.delete_selected_segments
176
+ )
177
+ QShortcut(QKeySequence(Qt.Key.Key_M), self, self.handle_merge_press)
178
+ QShortcut(QKeySequence("Ctrl+Z"), self, self.undo_last_action)
179
+ QShortcut(
180
+ QKeySequence("Ctrl+A"), self, self.right_panel.segment_table.selectAll
181
+ )
182
+ QShortcut(QKeySequence(Qt.Key.Key_Space), self, self.handle_space_press)
183
+ QShortcut(QKeySequence(Qt.Key.Key_Return), self, self.handle_enter_press)
184
+ QShortcut(QKeySequence(Qt.Key.Key_Enter), self, self.handle_enter_press)
185
+ QShortcut(QKeySequence(Qt.Key.Key_Period), self, self.viewer.fitInView)
186
+
187
+ def toggle_left_panel(self):
188
+ is_visible = self.control_panel.main_controls_widget.isVisible()
189
+ self.control_panel.main_controls_widget.setVisible(not is_visible)
190
+ if is_visible:
191
+ self.control_panel.btn_toggle_visibility.setText("> Show")
192
+ self.control_panel.setFixedWidth(
193
+ self.control_panel.btn_toggle_visibility.sizeHint().width() + 20
194
+ )
195
+ else:
196
+ self.control_panel.btn_toggle_visibility.setText("< Hide")
197
+ self.control_panel.setFixedWidth(250)
198
+
199
+ def toggle_right_panel(self):
200
+ is_visible = self.right_panel.main_controls_widget.isVisible()
201
+ self.right_panel.main_controls_widget.setVisible(not is_visible)
202
+ layout = self.right_panel.v_layout
203
+
204
+ if is_visible: # Content is now hidden
205
+ layout.addStretch(1)
206
+ self.right_panel.btn_toggle_visibility.setText("< Show")
207
+ self.right_panel.setFixedWidth(
208
+ self.right_panel.btn_toggle_visibility.sizeHint().width() + 20
209
+ )
210
+ else: # Content is now visible
211
+ # Remove the stretch so the content can expand
212
+ for i in range(layout.count()):
213
+ item = layout.itemAt(i)
214
+ if isinstance(item, QSpacerItem):
215
+ layout.removeItem(item)
216
+ break
217
+ self.right_panel.btn_toggle_visibility.setText("Hide >")
218
+ self.right_panel.setFixedWidth(350)
219
+
220
+ def handle_save_checkbox_change(self):
221
+ is_npz_checked = self.control_panel.chk_save_npz.isChecked()
222
+ is_txt_checked = self.control_panel.chk_save_txt.isChecked()
223
+
224
+ if not is_npz_checked and not is_txt_checked:
225
+ sender = self.sender()
226
+ if sender == self.control_panel.chk_save_npz:
227
+ self.control_panel.chk_save_txt.setChecked(True)
228
+ else:
229
+ self.control_panel.chk_save_npz.setChecked(True)
230
+
231
+ def set_annotation_size(self, value):
232
+ multiplier = value / 10.0
233
+ self.point_radius = self._original_point_radius * multiplier
234
+ self.line_thickness = self._original_line_thickness * multiplier
235
+
236
+ self.control_panel.size_label.setText(f"Annotation Size: {multiplier:.1f}x")
237
+
238
+ if self.control_panel.size_slider.value() != value:
239
+ self.control_panel.size_slider.setValue(value)
240
+
241
+ self.display_all_segments()
242
+ self.clear_all_points()
243
+
244
+ def set_pan_multiplier(self, value):
245
+ self.pan_multiplier = value / 10.0
246
+ self.control_panel.pan_label.setText(f"Pan Speed: {self.pan_multiplier:.1f}x")
247
+
248
+ def set_polygon_join_threshold(self, value):
249
+ self.polygon_join_threshold = value
250
+ self.control_panel.join_label.setText(f"Polygon Join Distance: {value}px")
251
+
252
+ def handle_escape_press(self):
253
+ self.right_panel.segment_table.clearSelection()
254
+ self.right_panel.class_table.clearSelection()
255
+ self.clear_all_points()
256
+ self.viewer.setFocus()
257
+
258
+ def handle_space_press(self):
259
+ if self.mode == "polygon" and self.polygon_points:
260
+ self.finalize_polygon()
261
+ else:
262
+ self.save_current_segment()
263
+
264
+ def handle_enter_press(self):
265
+ if self.mode == "polygon" and self.polygon_points:
266
+ self.finalize_polygon()
267
+ else:
268
+ self.save_output_to_npz()
269
+
270
+ def handle_merge_press(self):
271
+ self.assign_selected_to_class()
272
+ self.right_panel.segment_table.clearSelection()
145
273
 
146
274
  def show_notification(self, message, duration=3000):
147
275
  self.control_panel.notification_label.setText(message)
@@ -152,10 +280,8 @@ class MainWindow(QMainWindow):
152
280
  def _get_color_for_class(self, class_id):
153
281
  if class_id is None:
154
282
  return QColor.fromHsv(0, 0, 128)
155
-
156
283
  hue = int((class_id * 222.4922359) % 360)
157
284
  color = QColor.fromHsv(hue, 220, 220)
158
-
159
285
  if not color.isValid():
160
286
  return QColor(Qt.GlobalColor.white)
161
287
  return color
@@ -256,6 +382,7 @@ class MainWindow(QMainWindow):
256
382
  self.reset_state()
257
383
  self.viewer.set_photo(pixmap)
258
384
  self.sam_model.set_image(self.current_image_path)
385
+ self.load_class_aliases()
259
386
  self.load_existing_mask()
260
387
  self.right_panel.file_tree.setCurrentIndex(index)
261
388
  self.viewer.setFocus()
@@ -288,9 +415,8 @@ class MainWindow(QMainWindow):
288
415
 
289
416
  def reset_state(self):
290
417
  self.clear_all_points()
291
- # Preserve aliases between images in the same session
292
- # self.class_aliases.clear()
293
418
  self.segments.clear()
419
+ self.class_aliases.clear()
294
420
  self.next_class_id = 0
295
421
  self.update_all_lists()
296
422
  items_to_remove = [
@@ -305,83 +431,55 @@ class MainWindow(QMainWindow):
305
431
 
306
432
  def keyPressEvent(self, event):
307
433
  key, mods = event.key(), event.modifiers()
308
- if event.isAutoRepeat():
434
+
435
+ if event.isAutoRepeat() and key not in {
436
+ Qt.Key.Key_W,
437
+ Qt.Key.Key_A,
438
+ Qt.Key.Key_S,
439
+ Qt.Key.Key_D,
440
+ }:
309
441
  return
310
442
 
311
- pan_multiplier = 5.0 if (mods & Qt.KeyboardModifier.ShiftModifier) else 2
443
+ shift_multiplier = 5.0 if mods & Qt.KeyboardModifier.ShiftModifier else 1.0
312
444
 
313
445
  if key == Qt.Key.Key_W:
314
- amount = int(self.viewer.height() * 0.1 * pan_multiplier)
446
+ amount = int(
447
+ self.viewer.height() * 0.1 * self.pan_multiplier * shift_multiplier
448
+ )
315
449
  self.viewer.verticalScrollBar().setValue(
316
450
  self.viewer.verticalScrollBar().value() - amount
317
451
  )
318
452
  elif key == Qt.Key.Key_S:
319
- amount = int(self.viewer.height() * 0.1 * pan_multiplier)
453
+ amount = int(
454
+ self.viewer.height() * 0.1 * self.pan_multiplier * shift_multiplier
455
+ )
320
456
  self.viewer.verticalScrollBar().setValue(
321
457
  self.viewer.verticalScrollBar().value() + amount
322
458
  )
323
- elif key == Qt.Key.Key_A and not (mods & Qt.KeyboardModifier.ControlModifier):
324
- amount = int(self.viewer.width() * 0.1 * pan_multiplier)
459
+ elif key == Qt.Key.Key_A:
460
+ amount = int(
461
+ self.viewer.width() * 0.1 * self.pan_multiplier * shift_multiplier
462
+ )
325
463
  self.viewer.horizontalScrollBar().setValue(
326
464
  self.viewer.horizontalScrollBar().value() - amount
327
465
  )
328
466
  elif key == Qt.Key.Key_D:
329
- amount = int(self.viewer.width() * 0.1 * pan_multiplier)
467
+ amount = int(
468
+ self.viewer.width() * 0.1 * self.pan_multiplier * shift_multiplier
469
+ )
330
470
  self.viewer.horizontalScrollBar().setValue(
331
471
  self.viewer.horizontalScrollBar().value() + amount
332
472
  )
333
- elif key == Qt.Key.Key_Period:
334
- self.viewer.fitInView()
335
- # Other keybindings
336
- elif key == Qt.Key.Key_1:
337
- self.set_sam_mode()
338
- elif key == Qt.Key.Key_2:
339
- self.set_polygon_mode()
340
- elif key == Qt.Key.Key_E:
341
- self.toggle_selection_mode()
342
- elif key == Qt.Key.Key_Q:
343
- self.toggle_pan_mode()
344
- elif key == Qt.Key.Key_R:
345
- self.toggle_edit_mode()
346
- elif key == Qt.Key.Key_C or key == Qt.Key.Key_Escape:
347
- self.clear_all_points()
348
- elif key == Qt.Key.Key_V or key == Qt.Key.Key_Backspace:
349
- self.delete_selected_segments()
350
- elif key == Qt.Key.Key_M:
351
- self.assign_selected_to_class()
352
- self.right_panel.segment_table.clearSelection()
353
- elif key == Qt.Key.Key_Z and mods == Qt.KeyboardModifier.ControlModifier:
354
- self.undo_last_action()
355
- elif key == Qt.Key.Key_A and mods == Qt.KeyboardModifier.ControlModifier:
356
- self.right_panel.segment_table.selectAll()
357
- elif key == Qt.Key.Key_Space:
358
- if self.mode == "polygon" and self.polygon_points:
359
- self.finalize_polygon()
360
- else:
361
- self.save_current_segment()
362
- elif key == Qt.Key.Key_Return or key == Qt.Key.Key_Enter:
363
- if self.mode == "polygon" and self.polygon_points:
364
- self.finalize_polygon()
365
- else:
366
- self.save_output_to_npz()
367
473
  elif (
368
474
  key == Qt.Key.Key_Equal or key == Qt.Key.Key_Plus
369
475
  ) and mods == Qt.KeyboardModifier.ControlModifier:
370
- self.point_radius = min(20, self.point_radius + self._original_point_radius)
371
- self.line_thickness = min(
372
- 20, self.line_thickness + self._original_line_thickness
373
- )
374
- self.display_all_segments()
375
- self.clear_all_points()
476
+ current_val = self.control_panel.size_slider.value()
477
+ self.control_panel.size_slider.setValue(current_val + 1)
376
478
  elif key == Qt.Key.Key_Minus and mods == Qt.KeyboardModifier.ControlModifier:
377
- self.point_radius = max(
378
- 0.3, self.point_radius - self._original_point_radius
379
- )
380
- self.line_thickness = max(
381
- 0.5, self.line_thickness - self._original_line_thickness
382
- )
383
- self.display_all_segments()
384
- self.clear_all_points()
479
+ current_val = self.control_panel.size_slider.value()
480
+ self.control_panel.size_slider.setValue(current_val - 1)
481
+ else:
482
+ super().keyPressEvent(event)
385
483
 
386
484
  def scene_mouse_press(self, event):
387
485
  self._original_mouse_press(event)
@@ -432,15 +530,12 @@ class MainWindow(QMainWindow):
432
530
  elif self.mode == "polygon" and self.polygon_points:
433
531
  if self.rubber_band_line is None:
434
532
  self.rubber_band_line = QGraphicsLineItem()
435
-
436
533
  line_color = QColor(Qt.GlobalColor.white)
437
534
  line_color.setAlpha(150)
438
-
439
535
  self.rubber_band_line.setPen(
440
536
  QPen(line_color, self.line_thickness, Qt.PenStyle.DotLine)
441
537
  )
442
538
  self.viewer.scene().addItem(self.rubber_band_line)
443
-
444
539
  self.rubber_band_line.setLine(
445
540
  self.polygon_points[-1].x(),
446
541
  self.polygon_points[-1].y(),
@@ -454,7 +549,6 @@ class MainWindow(QMainWindow):
454
549
  def scene_mouse_release(self, event):
455
550
  if self.mode == "pan":
456
551
  self.viewer.set_cursor(Qt.CursorShape.OpenHandCursor)
457
-
458
552
  if self.mode == "edit" and self.is_dragging_polygon:
459
553
  self.is_dragging_polygon = False
460
554
  self.drag_initial_vertices.clear()
@@ -463,12 +557,10 @@ class MainWindow(QMainWindow):
463
557
  def undo_last_action(self):
464
558
  if self.mode == "polygon" and self.polygon_points:
465
559
  self.polygon_points.pop()
466
-
467
560
  for item in self.polygon_preview_items:
468
561
  if item.scene():
469
562
  self.viewer.scene().removeItem(item)
470
563
  self.polygon_preview_items.clear()
471
-
472
564
  for point in self.polygon_points:
473
565
  point_diameter = self.point_radius * 2
474
566
  point_color = QColor(Qt.GlobalColor.blue)
@@ -483,21 +575,17 @@ class MainWindow(QMainWindow):
483
575
  dot.setPen(QPen(Qt.GlobalColor.transparent))
484
576
  self.viewer.scene().addItem(dot)
485
577
  self.polygon_preview_items.append(dot)
486
-
487
578
  self.draw_polygon_preview()
488
-
489
579
  elif self.mode == "sam_points" and self.point_items:
490
580
  item_to_remove = self.point_items.pop()
491
581
  point_pos = item_to_remove.rect().topLeft() + QPointF(
492
582
  self.point_radius, self.point_radius
493
583
  )
494
584
  point_coords = [int(point_pos.x()), int(point_pos.y())]
495
-
496
585
  if point_coords in self.positive_points:
497
586
  self.positive_points.remove(point_coords)
498
587
  elif point_coords in self.negative_points:
499
588
  self.negative_points.remove(point_coords)
500
-
501
589
  self.viewer.scene().removeItem(item_to_remove)
502
590
  self.update_segmentation()
503
591
 
@@ -574,14 +662,13 @@ class MainWindow(QMainWindow):
574
662
  if existing_class_ids:
575
663
  target_class_id = min(existing_class_ids)
576
664
  else:
577
- target_class_id = self.segments[selected_indices[0]].get("class_id")
665
+ target_class_id = self.next_class_id
578
666
 
579
667
  for i in selected_indices:
580
668
  self.segments[i]["class_id"] = target_class_id
581
669
 
582
670
  self._update_next_class_id()
583
671
  self.update_all_lists()
584
- self.right_panel.segment_table.clearSelection()
585
672
  self.viewer.setFocus()
586
673
 
587
674
  def rasterize_polygon(self, vertices):
@@ -606,7 +693,6 @@ class MainWindow(QMainWindow):
606
693
  for i, seg_dict in enumerate(self.segments):
607
694
  self.segment_items[i] = []
608
695
  class_id = seg_dict.get("class_id")
609
-
610
696
  base_color = self._get_color_for_class(class_id)
611
697
 
612
698
  if seg_dict["type"] == "Polygon":
@@ -621,7 +707,6 @@ class MainWindow(QMainWindow):
621
707
  poly_item.setPen(QPen(Qt.GlobalColor.transparent))
622
708
  self.viewer.scene().addItem(poly_item)
623
709
  self.segment_items[i].append(poly_item)
624
-
625
710
  base_color.setAlpha(150)
626
711
  vertex_color = QBrush(base_color)
627
712
  point_diameter = self.point_radius * 2
@@ -658,7 +743,6 @@ class MainWindow(QMainWindow):
658
743
  hover_pixmap = mask_to_pixmap(
659
744
  seg_dict["mask"], base_color.getRgb()[:3], alpha=170
660
745
  )
661
-
662
746
  pixmap_item = HoverablePixmapItem()
663
747
  pixmap_item.set_pixmaps(default_pixmap, hover_pixmap)
664
748
  self.viewer.scene().addItem(pixmap_item)
@@ -697,7 +781,7 @@ class MainWindow(QMainWindow):
697
781
  self.highlight_items.append(highlight_item)
698
782
 
699
783
  def update_all_lists(self):
700
- self.update_class_list() # Must be before filter combo
784
+ self.update_class_list()
701
785
  self.update_class_filter_combo()
702
786
  self.update_segment_table()
703
787
  self.display_all_segments()
@@ -727,19 +811,24 @@ class MainWindow(QMainWindow):
727
811
  for row, (original_index, seg) in enumerate(display_segments):
728
812
  class_id = seg.get("class_id")
729
813
  color = self._get_color_for_class(class_id)
730
-
731
814
  class_id_str = str(class_id) if class_id is not None else "N/A"
815
+
816
+ alias_str = "N/A"
817
+ if class_id is not None:
818
+ alias_str = self.class_aliases.get(class_id, str(class_id))
819
+ alias_item = QTableWidgetItem(alias_str)
820
+
732
821
  index_item = NumericTableWidgetItem(str(original_index + 1))
733
822
  class_item = NumericTableWidgetItem(class_id_str)
734
- type_item = QTableWidgetItem(seg.get("type", "N/A"))
735
823
 
736
824
  index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
737
- type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
825
+ class_item.setFlags(class_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
826
+ alias_item.setFlags(alias_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
738
827
  index_item.setData(Qt.ItemDataRole.UserRole, original_index)
739
828
 
740
829
  table.setItem(row, 0, index_item)
741
830
  table.setItem(row, 1, class_item)
742
- table.setItem(row, 2, type_item)
831
+ table.setItem(row, 2, alias_item)
743
832
 
744
833
  for col in range(table.columnCount()):
745
834
  if table.item(row, col):
@@ -759,19 +848,7 @@ class MainWindow(QMainWindow):
759
848
  class_table = self.right_panel.class_table
760
849
  class_table.blockSignals(True)
761
850
 
762
- # Preserve existing aliases during update
763
- current_aliases = {}
764
- for row in range(class_table.rowCount()):
765
- try:
766
- alias = class_table.item(row, 0).text()
767
- cid = int(class_table.item(row, 1).text())
768
- current_aliases[cid] = alias
769
- except (AttributeError, ValueError):
770
- continue
771
- self.class_aliases.update(current_aliases)
772
-
773
- class_table.clearContents()
774
-
851
+ preserved_aliases = self.class_aliases.copy()
775
852
  unique_class_ids = sorted(
776
853
  list(
777
854
  {
@@ -781,19 +858,22 @@ class MainWindow(QMainWindow):
781
858
  }
782
859
  )
783
860
  )
784
- class_table.setRowCount(len(unique_class_ids))
785
861
 
862
+ new_aliases = {}
863
+ for cid in unique_class_ids:
864
+ new_aliases[cid] = preserved_aliases.get(cid, str(cid))
865
+
866
+ self.class_aliases = new_aliases
867
+
868
+ class_table.clearContents()
869
+ class_table.setRowCount(len(unique_class_ids))
786
870
  for row, cid in enumerate(unique_class_ids):
787
- alias = self.class_aliases.get(cid, str(cid))
788
- alias_item = QTableWidgetItem(alias)
871
+ alias_item = QTableWidgetItem(self.class_aliases.get(cid))
789
872
  id_item = QTableWidgetItem(str(cid))
790
-
791
873
  id_item.setFlags(id_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
792
-
793
874
  color = self._get_color_for_class(cid)
794
875
  alias_item.setBackground(QBrush(color))
795
876
  id_item.setBackground(QBrush(color))
796
-
797
877
  class_table.setItem(row, 0, alias_item)
798
878
  class_table.setItem(row, 1, id_item)
799
879
 
@@ -810,7 +890,6 @@ class MainWindow(QMainWindow):
810
890
  }
811
891
  )
812
892
  )
813
-
814
893
  current_selection = combo.currentText()
815
894
  combo.blockSignals(True)
816
895
  combo.clear()
@@ -821,7 +900,6 @@ class MainWindow(QMainWindow):
821
900
  for cid in unique_class_ids
822
901
  ]
823
902
  )
824
-
825
903
  if combo.findText(current_selection) > -1:
826
904
  combo.setCurrentText(current_selection)
827
905
  else:
@@ -830,7 +908,6 @@ class MainWindow(QMainWindow):
830
908
 
831
909
  def reassign_class_ids(self):
832
910
  class_table = self.right_panel.class_table
833
-
834
911
  ordered_ids = []
835
912
  for row in range(class_table.rowCount()):
836
913
  id_item = class_table.item(row, 1)
@@ -839,75 +916,37 @@ class MainWindow(QMainWindow):
839
916
  ordered_ids.append(int(id_item.text()))
840
917
  except ValueError:
841
918
  continue
842
-
843
919
  id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
844
-
845
920
  for seg in self.segments:
846
921
  old_id = seg.get("class_id")
847
922
  if old_id in id_map:
848
923
  seg["class_id"] = id_map[old_id]
849
-
850
924
  new_aliases = {
851
925
  id_map[old_id]: self.class_aliases.get(old_id, str(old_id))
852
926
  for old_id in ordered_ids
853
927
  if old_id in self.class_aliases
854
928
  }
855
929
  self.class_aliases = new_aliases
856
-
857
930
  self._update_next_class_id()
858
931
  self.update_all_lists()
859
932
  self.viewer.setFocus()
860
933
 
861
934
  def handle_alias_change(self, item):
862
- if item.column() != 0: # Alias column
935
+ if item.column() != 0:
863
936
  return
864
-
865
937
  class_table = self.right_panel.class_table
866
938
  class_table.blockSignals(True)
867
-
868
939
  id_item = class_table.item(item.row(), 1)
869
940
  if id_item:
870
941
  try:
871
942
  class_id = int(id_item.text())
872
943
  self.class_aliases[class_id] = item.text()
873
944
  except (ValueError, AttributeError):
874
- pass # Ignore if ID item is not valid
875
-
945
+ pass
876
946
  class_table.blockSignals(False)
877
- self.update_class_filter_combo() # Refresh filter to show new alias
878
-
879
- def handle_class_id_change(self, item):
880
- if item.column() != 1: # Class ID column in segment table
881
- return
882
- table = self.right_panel.segment_table
883
- index_item = table.item(item.row(), 0)
884
- if not index_item:
885
- return
886
-
887
- table.blockSignals(True)
888
- try:
889
- new_class_id_text = item.text()
890
- if not new_class_id_text.strip():
891
- raise ValueError("Class ID cannot be empty.")
892
- new_class_id = int(new_class_id_text)
893
- original_index = index_item.data(Qt.ItemDataRole.UserRole)
894
947
 
895
- if original_index is None or original_index >= len(self.segments):
896
- raise IndexError("Invalid segment index found in table.")
897
-
898
- self.segments[original_index]["class_id"] = new_class_id
899
- self._update_next_class_id()
900
- self.update_all_lists()
901
- except (ValueError, TypeError, AttributeError, IndexError) as e:
902
- original_index = index_item.data(Qt.ItemDataRole.UserRole)
903
- if original_index is not None and original_index < len(self.segments):
904
- original_class_id = self.segments[original_index].get("class_id")
905
- item.setText(
906
- str(original_class_id) if original_class_id is not None else "N/A"
907
- )
908
- finally:
909
- table.blockSignals(False)
910
- self.viewer.setFocus()
948
+ self.update_class_filter_combo()
949
+ self.update_segment_table()
911
950
 
912
951
  def get_selected_segment_indices(self):
913
952
  table = self.right_panel.segment_table
@@ -920,95 +959,136 @@ class MainWindow(QMainWindow):
920
959
  ]
921
960
 
922
961
  def save_output_to_npz(self):
923
- if not self.segments or not self.current_image_path:
962
+ save_npz = self.control_panel.chk_save_npz.isChecked()
963
+ save_txt = self.control_panel.chk_save_txt.isChecked()
964
+ save_aliases = self.control_panel.chk_save_class_aliases.isChecked()
965
+
966
+ if not self.current_image_path or not any([save_npz, save_txt, save_aliases]):
924
967
  return
968
+
925
969
  self.right_panel.status_label.setText("Saving...")
926
970
  QApplication.processEvents()
927
971
 
928
- output_path = os.path.splitext(self.current_image_path)[0] + ".npz"
929
- h, w = (
930
- self.viewer._pixmap_item.pixmap().height(),
931
- self.viewer._pixmap_item.pixmap().width(),
932
- )
933
-
934
- class_table = self.right_panel.class_table
935
- ordered_ids = [
936
- int(class_table.item(row, 1).text())
937
- for row in range(class_table.rowCount())
938
- if class_table.item(row, 1) is not None
939
- ]
940
-
941
- if not ordered_ids:
942
- self.right_panel.status_label.setText("Save failed: No classes defined.")
943
- QTimer.singleShot(3000, lambda: self.right_panel.status_label.clear())
944
- return
945
-
946
- id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
947
- num_final_classes = len(ordered_ids)
948
- final_mask_tensor = np.zeros((h, w, num_final_classes), dtype=np.uint8)
972
+ saved_something = False
949
973
 
950
- for seg in self.segments:
951
- class_id = seg.get("class_id")
952
- if class_id not in id_map:
953
- continue
954
- new_channel_idx = id_map[class_id]
955
- mask = (
956
- self.rasterize_polygon(seg["vertices"])
957
- if seg["type"] == "Polygon"
958
- else seg.get("mask")
959
- )
960
- if mask is not None:
961
- final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
962
- final_mask_tensor[:, :, new_channel_idx], mask
974
+ if save_npz or save_txt:
975
+ if not self.segments:
976
+ self.show_notification("No segments to save.")
977
+ else:
978
+ h, w = (
979
+ self.viewer._pixmap_item.pixmap().height(),
980
+ self.viewer._pixmap_item.pixmap().width(),
963
981
  )
982
+ class_table = self.right_panel.class_table
983
+ ordered_ids = [
984
+ int(class_table.item(row, 1).text())
985
+ for row in range(class_table.rowCount())
986
+ if class_table.item(row, 1) is not None
987
+ ]
988
+
989
+ if not ordered_ids:
990
+ self.show_notification("No classes defined for mask saving.")
991
+ else:
992
+ id_map = {
993
+ old_id: new_id for new_id, old_id in enumerate(ordered_ids)
994
+ }
995
+ num_final_classes = len(ordered_ids)
996
+ final_mask_tensor = np.zeros(
997
+ (h, w, num_final_classes), dtype=np.uint8
998
+ )
964
999
 
965
- np.savez_compressed(output_path, mask=final_mask_tensor.astype(np.uint8))
1000
+ for seg in self.segments:
1001
+ class_id = seg.get("class_id")
1002
+ if class_id not in id_map:
1003
+ continue
1004
+ new_channel_idx = id_map[class_id]
1005
+ mask = (
1006
+ self.rasterize_polygon(seg["vertices"])
1007
+ if seg["type"] == "Polygon"
1008
+ else seg.get("mask")
1009
+ )
1010
+ if mask is not None:
1011
+ final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
1012
+ final_mask_tensor[:, :, new_channel_idx], mask
1013
+ )
1014
+ if save_npz:
1015
+ npz_path = os.path.splitext(self.current_image_path)[0] + ".npz"
1016
+ np.savez_compressed(
1017
+ npz_path, mask=final_mask_tensor.astype(np.uint8)
1018
+ )
1019
+ self.file_model.update_cache_for_path(npz_path)
966
1020
 
967
- self.file_model.set_highlighted_path(output_path)
968
- QTimer.singleShot(1500, lambda: self.file_model.set_highlighted_path(None))
1021
+ self.file_model.set_highlighted_path(npz_path)
1022
+ QTimer.singleShot(
1023
+ 1500, lambda: self.file_model.set_highlighted_path(None)
1024
+ )
1025
+ saved_something = True
1026
+ if save_txt:
1027
+ if self.control_panel.chk_yolo_use_alias.isChecked():
1028
+ class_labels = [
1029
+ class_table.item(row, 0).text()
1030
+ for row in range(class_table.rowCount())
1031
+ ]
1032
+ else:
1033
+ class_labels = list(range(num_final_classes))
1034
+
1035
+ txt_path = self.generate_yolo_annotations(
1036
+ final_mask_tensor, class_labels
1037
+ )
1038
+ if txt_path:
1039
+ self.file_model.update_cache_for_path(txt_path)
1040
+ saved_something = True
1041
+
1042
+ if save_aliases:
1043
+ aliases_path = os.path.splitext(self.current_image_path)[0] + ".json"
1044
+ aliases_to_save = {str(k): v for k, v in self.class_aliases.items()}
1045
+ with open(aliases_path, "w") as f:
1046
+ json.dump(aliases_to_save, f, indent=4)
1047
+ saved_something = True
1048
+
1049
+ if saved_something:
1050
+ self.right_panel.status_label.setText("Saved!")
1051
+ else:
1052
+ self.right_panel.status_label.clear()
969
1053
 
970
- self.right_panel.status_label.setText("Saved!")
971
- self.generate_yolo_annotations(npz_file_path=output_path)
972
1054
  QTimer.singleShot(3000, lambda: self.right_panel.status_label.clear())
973
1055
 
974
- def generate_yolo_annotations(self, npz_file_path):
1056
+ def generate_yolo_annotations(self, mask_tensor, class_labels):
975
1057
  output_path = os.path.splitext(self.current_image_path)[0] + ".txt"
976
- npz_data = np.load(npz_file_path) # Load the saved npz file
977
-
978
- img = npz_data["mask"][:, :, :]
979
- num_channels = img.shape[2] # C
980
- h, w = img.shape[:2] # H, W
1058
+ h, w, num_channels = mask_tensor.shape
981
1059
 
982
1060
  directory_path = os.path.dirname(output_path)
983
1061
  os.makedirs(directory_path, exist_ok=True)
984
1062
 
985
1063
  yolo_annotations = []
986
-
987
1064
  for channel in range(num_channels):
988
- single_channel_image = img[:, :, channel]
1065
+ single_channel_image = mask_tensor[:, :, channel]
1066
+ if not np.any(single_channel_image):
1067
+ continue
1068
+
989
1069
  contours, _ = cv2.findContours(
990
1070
  single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
991
1071
  )
992
1072
 
993
- class_id = channel # Use the channel index as the class ID
994
-
1073
+ class_label = class_labels[channel]
995
1074
  for contour in contours:
996
1075
  x, y, width, height = cv2.boundingRect(contour)
997
- center_x = x + width / 2
998
- center_y = y + height / 2
999
-
1000
- normalized_center_x = center_x / w
1001
- normalized_center_y = center_y / h
1076
+ center_x = (x + width / 2) / w
1077
+ center_y = (y + height / 2) / h
1002
1078
  normalized_width = width / w
1003
1079
  normalized_height = height / h
1004
-
1005
- yolo_entry = f"{class_id} {normalized_center_x} {normalized_center_y} {normalized_width} {normalized_height}"
1080
+ yolo_entry = f"{class_label} {center_x} {center_y} {normalized_width} {normalized_height}"
1006
1081
  yolo_annotations.append(yolo_entry)
1007
1082
 
1083
+ if not yolo_annotations:
1084
+ return None
1085
+
1008
1086
  with open(output_path, "w") as file:
1009
1087
  for annotation in yolo_annotations:
1010
1088
  file.write(annotation + "\n")
1011
1089
 
1090
+ return output_path
1091
+
1012
1092
  def save_current_segment(self):
1013
1093
  if (
1014
1094
  self.mode != "sam_points"
@@ -1040,6 +1120,20 @@ class MainWindow(QMainWindow):
1040
1120
  self.update_all_lists()
1041
1121
  self.viewer.setFocus()
1042
1122
 
1123
+ def load_class_aliases(self):
1124
+ if not self.current_image_path:
1125
+ return
1126
+ json_path = os.path.splitext(self.current_image_path)[0] + ".json"
1127
+ if os.path.exists(json_path):
1128
+ try:
1129
+ with open(json_path, "r") as f:
1130
+ loaded_aliases = json.load(f)
1131
+ # JSON loads keys as strings, convert them to int
1132
+ self.class_aliases = {int(k): v for k, v in loaded_aliases.items()}
1133
+ except (json.JSONDecodeError, ValueError) as e:
1134
+ print(f"Error loading class aliases from {json_path}: {e}")
1135
+ self.class_aliases.clear()
1136
+
1043
1137
  def load_existing_mask(self):
1044
1138
  if not self.current_image_path:
1045
1139
  return
@@ -1068,12 +1162,10 @@ class MainWindow(QMainWindow):
1068
1162
  def add_point(self, pos, positive):
1069
1163
  point_list = self.positive_points if positive else self.negative_points
1070
1164
  point_list.append([int(pos.x()), int(pos.y())])
1071
-
1072
1165
  point_color = (
1073
1166
  QColor(Qt.GlobalColor.green) if positive else QColor(Qt.GlobalColor.red)
1074
1167
  )
1075
1168
  point_color.setAlpha(150)
1076
-
1077
1169
  point_diameter = self.point_radius * 2
1078
1170
  point_item = QGraphicsEllipseItem(
1079
1171
  pos.x() - self.point_radius,
@@ -1120,17 +1212,15 @@ class MainWindow(QMainWindow):
1120
1212
  (pos.x() - self.polygon_points[0].x()) ** 2
1121
1213
  + (pos.y() - self.polygon_points[0].y()) ** 2
1122
1214
  )
1123
- < 4 # pixel distance threshold squared
1215
+ < self.polygon_join_threshold**2
1124
1216
  ):
1125
1217
  if len(self.polygon_points) > 2:
1126
1218
  self.finalize_polygon()
1127
1219
  return
1128
1220
  self.polygon_points.append(pos)
1129
1221
  point_diameter = self.point_radius * 2
1130
-
1131
1222
  point_color = QColor(Qt.GlobalColor.blue)
1132
1223
  point_color.setAlpha(150)
1133
-
1134
1224
  dot = QGraphicsEllipseItem(
1135
1225
  pos.x() - self.point_radius,
1136
1226
  pos.y() - self.point_radius,
@@ -1144,7 +1234,6 @@ class MainWindow(QMainWindow):
1144
1234
  self.draw_polygon_preview()
1145
1235
 
1146
1236
  def draw_polygon_preview(self):
1147
- # Clean up old preview lines/polygons
1148
1237
  for item in self.polygon_preview_items:
1149
1238
  if not isinstance(item, QGraphicsEllipseItem):
1150
1239
  if item.scene():
@@ -1154,7 +1243,6 @@ class MainWindow(QMainWindow):
1154
1243
  for item in self.polygon_preview_items
1155
1244
  if isinstance(item, QGraphicsEllipseItem)
1156
1245
  ]
1157
-
1158
1246
  if len(self.polygon_points) > 2:
1159
1247
  preview_poly = QGraphicsPolygonItem(QPolygonF(self.polygon_points))
1160
1248
  preview_poly.setBrush(QBrush(QColor(0, 255, 255, 100)))