lazylabel-gui 1.0.6__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lazylabel/main.py CHANGED
@@ -16,8 +16,18 @@ from PyQt6.QtWidgets import (
16
16
  QGraphicsPolygonItem,
17
17
  QTableWidgetSelectionRange,
18
18
  )
19
- from PyQt6.QtGui import QPixmap, QColor, QPen, QBrush, QPolygonF, QIcon
20
- from PyQt6.QtCore import Qt, QPointF, QTimer
19
+ from PyQt6.QtGui import (
20
+ QPixmap,
21
+ QColor,
22
+ QPen,
23
+ QBrush,
24
+ QPolygonF,
25
+ QIcon,
26
+ QCursor,
27
+ QKeySequence,
28
+ QShortcut,
29
+ )
30
+ from PyQt6.QtCore import Qt, QPointF, QTimer, QModelIndex
21
31
 
22
32
  from .photo_viewer import PhotoViewer
23
33
  from .sam_model import SamModel
@@ -26,6 +36,7 @@ from .controls import ControlPanel, RightPanel
26
36
  from .custom_file_system_model import CustomFileSystemModel
27
37
  from .editable_vertex import EditableVertexItem
28
38
  from .hoverable_polygon_item import HoverablePolygonItem
39
+ from .hoverable_pixelmap_item import HoverablePixmapItem
29
40
  from .numeric_table_widget_item import NumericTableWidgetItem
30
41
 
31
42
 
@@ -34,7 +45,9 @@ class MainWindow(QMainWindow):
34
45
  super().__init__()
35
46
  self.setWindowTitle("LazyLabel by DNC")
36
47
 
37
- icon_path = os.path.join(os.path.dirname(__file__), "demo_pictures", "logo2.png")
48
+ icon_path = os.path.join(
49
+ os.path.dirname(__file__), "demo_pictures", "logo2.png"
50
+ )
38
51
  if os.path.exists(icon_path):
39
52
  self.setWindowIcon(QIcon(icon_path))
40
53
 
@@ -44,9 +57,18 @@ class MainWindow(QMainWindow):
44
57
  self.mode = "sam_points"
45
58
  self.previous_mode = "sam_points"
46
59
  self.current_image_path = None
47
- self.current_file_index = None
60
+ self.current_file_index = QModelIndex()
61
+
48
62
  self.next_class_id = 0
49
63
 
64
+ self.class_aliases = {} # {class_id: "alias_string"}
65
+
66
+ self.point_radius = 0.3
67
+ self.line_thickness = 0.5
68
+
69
+ self._original_point_radius = self.point_radius
70
+ self._original_line_thickness = self.line_thickness
71
+
50
72
  self.point_items, self.positive_points, self.negative_points = [], [], []
51
73
  self.polygon_points, self.polygon_preview_items = [], []
52
74
  self.rubber_band_line = None
@@ -74,7 +96,9 @@ class MainWindow(QMainWindow):
74
96
  central_widget.setLayout(main_layout)
75
97
  self.setCentralWidget(central_widget)
76
98
 
77
- self.control_panel.device_label.setText(f"Device: {str(self.sam_model.device).upper()}")
99
+ self.control_panel.device_label.setText(
100
+ f"Device: {str(self.sam_model.device).upper()}"
101
+ )
78
102
  self.setup_connections()
79
103
  self.set_sam_mode()
80
104
 
@@ -89,24 +113,48 @@ class MainWindow(QMainWindow):
89
113
 
90
114
  self.right_panel.btn_open_folder.clicked.connect(self.open_folder_dialog)
91
115
  self.right_panel.file_tree.doubleClicked.connect(self.load_selected_image)
92
- self.right_panel.btn_merge_selection.clicked.connect(self.assign_selected_to_class)
93
- self.right_panel.btn_delete_selection.clicked.connect(self.delete_selected_segments)
94
- self.right_panel.segment_table.itemSelectionChanged.connect(self.highlight_selected_segments)
116
+ self.right_panel.btn_merge_selection.clicked.connect(
117
+ self.assign_selected_to_class
118
+ )
119
+ self.right_panel.btn_delete_selection.clicked.connect(
120
+ self.delete_selected_segments
121
+ )
122
+ self.right_panel.segment_table.itemSelectionChanged.connect(
123
+ self.highlight_selected_segments
124
+ )
95
125
  self.right_panel.segment_table.itemChanged.connect(self.handle_class_id_change)
126
+ self.right_panel.class_table.itemChanged.connect(self.handle_alias_change)
96
127
  self.right_panel.btn_reassign_classes.clicked.connect(self.reassign_class_ids)
97
- self.right_panel.class_filter_combo.currentIndexChanged.connect(self.update_segment_table)
128
+ self.right_panel.class_filter_combo.currentIndexChanged.connect(
129
+ self.update_segment_table
130
+ )
98
131
 
99
132
  self.control_panel.btn_sam_mode.clicked.connect(self.set_sam_mode)
100
133
  self.control_panel.btn_polygon_mode.clicked.connect(self.set_polygon_mode)
101
- self.control_panel.btn_selection_mode.clicked.connect(self.toggle_selection_mode)
134
+ self.control_panel.btn_selection_mode.clicked.connect(
135
+ self.toggle_selection_mode
136
+ )
102
137
  self.control_panel.btn_clear_points.clicked.connect(self.clear_all_points)
138
+ self.control_panel.btn_fit_view.clicked.connect(self.viewer.fitInView)
139
+
140
+ # **FIX:** Use QShortcut for reliable global hotkeys
141
+ next_shortcut = QShortcut(QKeySequence(Qt.Key.Key_Right), self)
142
+ next_shortcut.activated.connect(self.load_next_image)
143
+ prev_shortcut = QShortcut(QKeySequence(Qt.Key.Key_Left), self)
144
+ prev_shortcut.activated.connect(self.load_previous_image)
145
+
146
+ def show_notification(self, message, duration=3000):
147
+ self.control_panel.notification_label.setText(message)
148
+ QTimer.singleShot(
149
+ duration, lambda: self.control_panel.notification_label.clear()
150
+ )
103
151
 
104
- def _get_color_for_class(self, class_id, saturation, value):
152
+ def _get_color_for_class(self, class_id):
105
153
  if class_id is None:
106
154
  return QColor.fromHsv(0, 0, 128)
107
155
 
108
156
  hue = int((class_id * 222.4922359) % 360)
109
- color = QColor.fromHsv(hue, saturation, value)
157
+ color = QColor.fromHsv(hue, 220, 220)
110
158
 
111
159
  if not color.isValid():
112
160
  return QColor(Qt.GlobalColor.white)
@@ -122,9 +170,25 @@ class MainWindow(QMainWindow):
122
170
  self.previous_mode = self.mode
123
171
 
124
172
  self.mode = mode_name
125
- self.control_panel.mode_label.setText(f"Mode: {mode_name.replace('_', ' ').title()}")
173
+ self.control_panel.mode_label.setText(
174
+ f"Mode: {mode_name.replace('_', ' ').title()}"
175
+ )
126
176
  self.clear_all_points()
127
- self.viewer.setDragMode(self.viewer.DragMode.ScrollHandDrag if self.mode == "pan" else self.viewer.DragMode.NoDrag)
177
+
178
+ cursor_map = {
179
+ "sam_points": Qt.CursorShape.CrossCursor,
180
+ "polygon": Qt.CursorShape.CrossCursor,
181
+ "selection": Qt.CursorShape.ArrowCursor,
182
+ "edit": Qt.CursorShape.SizeAllCursor,
183
+ "pan": Qt.CursorShape.OpenHandCursor,
184
+ }
185
+ self.viewer.set_cursor(cursor_map.get(self.mode, Qt.CursorShape.ArrowCursor))
186
+
187
+ self.viewer.setDragMode(
188
+ self.viewer.DragMode.ScrollHandDrag
189
+ if self.mode == "pan"
190
+ else self.viewer.DragMode.NoDrag
191
+ )
128
192
 
129
193
  def set_sam_mode(self):
130
194
  self.set_mode("sam_points")
@@ -148,27 +212,44 @@ class MainWindow(QMainWindow):
148
212
 
149
213
  def toggle_edit_mode(self):
150
214
  selected_indices = self.get_selected_segment_indices()
151
- can_edit = any(self.segments[i].get("type") == "Polygon" for i in selected_indices)
215
+
152
216
  if self.mode == "edit":
153
217
  self.set_mode("selection", is_toggle=True)
154
- elif self.mode == "selection" and can_edit:
155
- self.set_mode("edit", is_toggle=True)
156
- self.display_all_segments()
218
+ return
219
+
220
+ if not selected_indices:
221
+ self.show_notification("Select a polygon to edit.")
222
+ return
223
+
224
+ can_edit = any(
225
+ self.segments[i].get("type") == "Polygon" for i in selected_indices
226
+ )
227
+
228
+ if not can_edit:
229
+ self.show_notification("Only polygon segments can be edited.")
230
+ return
231
+
232
+ self.set_mode("edit", is_toggle=True)
233
+ self.display_all_segments()
157
234
 
158
235
  def open_folder_dialog(self):
159
236
  folder_path = QFileDialog.getExistingDirectory(self, "Select Image Folder")
160
237
  if folder_path:
161
- self.right_panel.file_tree.setRootIndex(self.file_model.setRootPath(folder_path))
238
+ self.right_panel.file_tree.setRootIndex(
239
+ self.file_model.setRootPath(folder_path)
240
+ )
162
241
  self.viewer.setFocus()
163
242
 
164
243
  def load_selected_image(self, index):
165
- if not index.isValid():
244
+ if not index.isValid() or not self.file_model.isDir(index.parent()):
166
245
  return
167
246
 
168
247
  self.current_file_index = index
169
248
  path = self.file_model.filePath(index)
170
249
 
171
- if os.path.isfile(path) and path.lower().endswith((".png", ".jpg", ".jpeg", ".tiff", ".tif")):
250
+ if os.path.isfile(path) and path.lower().endswith(
251
+ (".png", ".jpg", ".jpeg", ".tiff", ".tif")
252
+ ):
172
253
  self.current_image_path = path
173
254
  pixmap = QPixmap(self.current_image_path)
174
255
  if not pixmap.isNull():
@@ -176,14 +257,47 @@ class MainWindow(QMainWindow):
176
257
  self.viewer.set_photo(pixmap)
177
258
  self.sam_model.set_image(self.current_image_path)
178
259
  self.load_existing_mask()
260
+ self.right_panel.file_tree.setCurrentIndex(index)
179
261
  self.viewer.setFocus()
180
262
 
263
+ def load_next_image(self):
264
+ if not self.current_file_index.isValid():
265
+ return
266
+
267
+ if self.control_panel.chk_auto_save.isChecked():
268
+ self.save_output_to_npz()
269
+
270
+ row = self.current_file_index.row()
271
+ parent = self.current_file_index.parent()
272
+ if row + 1 < self.file_model.rowCount(parent):
273
+ next_index = self.file_model.index(row + 1, 0, parent)
274
+ self.load_selected_image(next_index)
275
+
276
+ def load_previous_image(self):
277
+ if not self.current_file_index.isValid():
278
+ return
279
+
280
+ if self.control_panel.chk_auto_save.isChecked():
281
+ self.save_output_to_npz()
282
+
283
+ row = self.current_file_index.row()
284
+ parent = self.current_file_index.parent()
285
+ if row > 0:
286
+ prev_index = self.file_model.index(row - 1, 0, parent)
287
+ self.load_selected_image(prev_index)
288
+
181
289
  def reset_state(self):
182
290
  self.clear_all_points()
291
+ # Preserve aliases between images in the same session
292
+ # self.class_aliases.clear()
183
293
  self.segments.clear()
184
294
  self.next_class_id = 0
185
295
  self.update_all_lists()
186
- items_to_remove = [item for item in self.viewer.scene().items() if item is not self.viewer._pixmap_item]
296
+ items_to_remove = [
297
+ item
298
+ for item in self.viewer.scene().items()
299
+ if item is not self.viewer._pixmap_item
300
+ ]
187
301
  for item in items_to_remove:
188
302
  self.viewer.scene().removeItem(item)
189
303
  self.segment_items.clear()
@@ -193,14 +307,32 @@ class MainWindow(QMainWindow):
193
307
  key, mods = event.key(), event.modifiers()
194
308
  if event.isAutoRepeat():
195
309
  return
310
+
311
+ pan_multiplier = 5.0 if (mods & Qt.KeyboardModifier.ShiftModifier) else 2
312
+
196
313
  if key == Qt.Key.Key_W:
197
- self.viewer.verticalScrollBar().setValue(self.viewer.verticalScrollBar().value() - int(self.viewer.height() * 0.1))
198
- elif key == Qt.Key.Key_S and not mods:
199
- self.viewer.verticalScrollBar().setValue(self.viewer.verticalScrollBar().value() + int(self.viewer.height() * 0.1))
314
+ amount = int(self.viewer.height() * 0.1 * pan_multiplier)
315
+ self.viewer.verticalScrollBar().setValue(
316
+ self.viewer.verticalScrollBar().value() - amount
317
+ )
318
+ elif key == Qt.Key.Key_S:
319
+ amount = int(self.viewer.height() * 0.1 * pan_multiplier)
320
+ self.viewer.verticalScrollBar().setValue(
321
+ self.viewer.verticalScrollBar().value() + amount
322
+ )
200
323
  elif key == Qt.Key.Key_A and not (mods & Qt.KeyboardModifier.ControlModifier):
201
- self.viewer.horizontalScrollBar().setValue(self.viewer.horizontalScrollBar().value() - int(self.viewer.width() * 0.1))
324
+ amount = int(self.viewer.width() * 0.1 * pan_multiplier)
325
+ self.viewer.horizontalScrollBar().setValue(
326
+ self.viewer.horizontalScrollBar().value() - amount
327
+ )
202
328
  elif key == Qt.Key.Key_D:
203
- self.viewer.horizontalScrollBar().setValue(self.viewer.horizontalScrollBar().value() + int(self.viewer.width() * 0.1))
329
+ amount = int(self.viewer.width() * 0.1 * pan_multiplier)
330
+ self.viewer.horizontalScrollBar().setValue(
331
+ self.viewer.horizontalScrollBar().value() + amount
332
+ )
333
+ elif key == Qt.Key.Key_Period:
334
+ self.viewer.fitInView()
335
+ # Other keybindings
204
336
  elif key == Qt.Key.Key_1:
205
337
  self.set_sam_mode()
206
338
  elif key == Qt.Key.Key_2:
@@ -223,16 +355,47 @@ class MainWindow(QMainWindow):
223
355
  elif key == Qt.Key.Key_A and mods == Qt.KeyboardModifier.ControlModifier:
224
356
  self.right_panel.segment_table.selectAll()
225
357
  elif key == Qt.Key.Key_Space:
226
- self.save_current_segment()
358
+ if self.mode == "polygon" and self.polygon_points:
359
+ self.finalize_polygon()
360
+ else:
361
+ self.save_current_segment()
227
362
  elif key == Qt.Key.Key_Return or key == Qt.Key.Key_Enter:
228
- self.save_output_to_npz()
363
+ if self.mode == "polygon" and self.polygon_points:
364
+ self.finalize_polygon()
365
+ else:
366
+ self.save_output_to_npz()
367
+ elif (
368
+ key == Qt.Key.Key_Equal or key == Qt.Key.Key_Plus
369
+ ) and mods == Qt.KeyboardModifier.ControlModifier:
370
+ self.point_radius = min(20, self.point_radius + self._original_point_radius)
371
+ self.line_thickness = min(
372
+ 20, self.line_thickness + self._original_line_thickness
373
+ )
374
+ self.display_all_segments()
375
+ self.clear_all_points()
376
+ elif key == Qt.Key.Key_Minus and mods == Qt.KeyboardModifier.ControlModifier:
377
+ self.point_radius = max(
378
+ 0.3, self.point_radius - self._original_point_radius
379
+ )
380
+ self.line_thickness = max(
381
+ 0.5, self.line_thickness - self._original_line_thickness
382
+ )
383
+ self.display_all_segments()
384
+ self.clear_all_points()
229
385
 
230
386
  def scene_mouse_press(self, event):
231
387
  self._original_mouse_press(event)
232
388
  if event.isAccepted():
233
389
  return
390
+
391
+ if self.mode == "pan":
392
+ self.viewer.set_cursor(Qt.CursorShape.ClosedHandCursor)
393
+
234
394
  pos = event.scenePos()
235
- if self.viewer._pixmap_item.pixmap().isNull() or not self.viewer._pixmap_item.pixmap().rect().contains(pos.toPoint()):
395
+ if (
396
+ self.viewer._pixmap_item.pixmap().isNull()
397
+ or not self.viewer._pixmap_item.pixmap().rect().contains(pos.toPoint())
398
+ ):
236
399
  return
237
400
  if self.mode == "sam_points":
238
401
  if event.button() == Qt.MouseButton.LeftButton:
@@ -252,7 +415,9 @@ class MainWindow(QMainWindow):
252
415
  self.is_dragging_polygon = True
253
416
  selected_indices = self.get_selected_segment_indices()
254
417
  self.drag_initial_vertices = {
255
- i: list(self.segments[i]["vertices"]) for i in selected_indices if self.segments[i].get("type") == "Polygon"
418
+ i: list(self.segments[i]["vertices"])
419
+ for i in selected_indices
420
+ if self.segments[i].get("type") == "Polygon"
256
421
  }
257
422
 
258
423
  def scene_mouse_move(self, event):
@@ -260,13 +425,22 @@ class MainWindow(QMainWindow):
260
425
  if self.mode == "edit" and self.is_dragging_polygon:
261
426
  delta = pos - self.drag_start_pos
262
427
  for i, initial_verts in self.drag_initial_vertices.items():
263
- self.segments[i]["vertices"] = [QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts]
428
+ self.segments[i]["vertices"] = [
429
+ QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts
430
+ ]
264
431
  self.update_polygon_visuals(i)
265
432
  elif self.mode == "polygon" and self.polygon_points:
266
433
  if self.rubber_band_line is None:
267
434
  self.rubber_band_line = QGraphicsLineItem()
268
- self.rubber_band_line.setPen(QPen(Qt.GlobalColor.white, 2, Qt.PenStyle.DotLine))
435
+
436
+ line_color = QColor(Qt.GlobalColor.white)
437
+ line_color.setAlpha(150)
438
+
439
+ self.rubber_band_line.setPen(
440
+ QPen(line_color, self.line_thickness, Qt.PenStyle.DotLine)
441
+ )
269
442
  self.viewer.scene().addItem(self.rubber_band_line)
443
+
270
444
  self.rubber_band_line.setLine(
271
445
  self.polygon_points[-1].x(),
272
446
  self.polygon_points[-1].y(),
@@ -278,6 +452,9 @@ class MainWindow(QMainWindow):
278
452
  self._original_mouse_move(event)
279
453
 
280
454
  def scene_mouse_release(self, event):
455
+ if self.mode == "pan":
456
+ self.viewer.set_cursor(Qt.CursorShape.OpenHandCursor)
457
+
281
458
  if self.mode == "edit" and self.is_dragging_polygon:
282
459
  self.is_dragging_polygon = False
283
460
  self.drag_initial_vertices.clear()
@@ -286,20 +463,55 @@ class MainWindow(QMainWindow):
286
463
  def undo_last_action(self):
287
464
  if self.mode == "polygon" and self.polygon_points:
288
465
  self.polygon_points.pop()
289
- if self.polygon_preview_items:
290
- self.viewer.scene().removeItem(self.polygon_preview_items.pop())
466
+
467
+ for item in self.polygon_preview_items:
468
+ if item.scene():
469
+ self.viewer.scene().removeItem(item)
470
+ self.polygon_preview_items.clear()
471
+
472
+ for point in self.polygon_points:
473
+ point_diameter = self.point_radius * 2
474
+ point_color = QColor(Qt.GlobalColor.blue)
475
+ point_color.setAlpha(150)
476
+ dot = QGraphicsEllipseItem(
477
+ point.x() - self.point_radius,
478
+ point.y() - self.point_radius,
479
+ point_diameter,
480
+ point_diameter,
481
+ )
482
+ dot.setBrush(QBrush(point_color))
483
+ dot.setPen(QPen(Qt.GlobalColor.transparent))
484
+ self.viewer.scene().addItem(dot)
485
+ self.polygon_preview_items.append(dot)
486
+
291
487
  self.draw_polygon_preview()
488
+
292
489
  elif self.mode == "sam_points" and self.point_items:
293
490
  item_to_remove = self.point_items.pop()
294
- point_pos = item_to_remove.rect().topLeft() + QPointF(4, 4)
491
+ point_pos = item_to_remove.rect().topLeft() + QPointF(
492
+ self.point_radius, self.point_radius
493
+ )
295
494
  point_coords = [int(point_pos.x()), int(point_pos.y())]
495
+
296
496
  if point_coords in self.positive_points:
297
497
  self.positive_points.remove(point_coords)
298
498
  elif point_coords in self.negative_points:
299
499
  self.negative_points.remove(point_coords)
500
+
300
501
  self.viewer.scene().removeItem(item_to_remove)
301
502
  self.update_segmentation()
302
503
 
504
+ def _update_next_class_id(self):
505
+ all_ids = {
506
+ seg.get("class_id")
507
+ for seg in self.segments
508
+ if seg.get("class_id") is not None
509
+ }
510
+ if not all_ids:
511
+ self.next_class_id = 0
512
+ else:
513
+ self.next_class_id = max(all_ids) + 1
514
+
303
515
  def finalize_polygon(self):
304
516
  if len(self.polygon_points) < 3:
305
517
  return
@@ -314,7 +526,7 @@ class MainWindow(QMainWindow):
314
526
  "class_id": self.next_class_id,
315
527
  }
316
528
  )
317
- self.next_class_id += 1
529
+ self._update_next_class_id()
318
530
  self.polygon_points.clear()
319
531
  for item in self.polygon_preview_items:
320
532
  self.viewer.scene().removeItem(item)
@@ -325,14 +537,25 @@ class MainWindow(QMainWindow):
325
537
  x, y = int(pos.x()), int(pos.y())
326
538
  for i in range(len(self.segments) - 1, -1, -1):
327
539
  seg = self.segments[i]
328
- mask = self.rasterize_polygon(seg["vertices"]) if seg["type"] == "Polygon" else seg.get("mask")
329
- if mask is not None and y < mask.shape[0] and x < mask.shape[1] and mask[y, x]:
540
+ mask = (
541
+ self.rasterize_polygon(seg["vertices"])
542
+ if seg["type"] == "Polygon"
543
+ else seg.get("mask")
544
+ )
545
+ if (
546
+ mask is not None
547
+ and y < mask.shape[0]
548
+ and x < mask.shape[1]
549
+ and mask[y, x]
550
+ ):
330
551
  for j in range(self.right_panel.segment_table.rowCount()):
331
552
  item = self.right_panel.segment_table.item(j, 0)
332
553
  if item and item.data(Qt.ItemDataRole.UserRole) == i:
333
554
  table = self.right_panel.segment_table
334
555
  is_selected = table.item(j, 0).isSelected()
335
- range_to_select = QTableWidgetSelectionRange(j, 0, j, table.columnCount() - 1)
556
+ range_to_select = QTableWidgetSelectionRange(
557
+ j, 0, j, table.columnCount() - 1
558
+ )
336
559
  table.setRangeSelected(range_to_select, not is_selected)
337
560
  return
338
561
  self.viewer.setFocus()
@@ -342,7 +565,11 @@ class MainWindow(QMainWindow):
342
565
  if not selected_indices:
343
566
  return
344
567
 
345
- existing_class_ids = [self.segments[i]["class_id"] for i in selected_indices if self.segments[i].get("class_id") is not None]
568
+ existing_class_ids = [
569
+ self.segments[i]["class_id"]
570
+ for i in selected_indices
571
+ if self.segments[i].get("class_id") is not None
572
+ ]
346
573
 
347
574
  if existing_class_ids:
348
575
  target_class_id = min(existing_class_ids)
@@ -352,6 +579,7 @@ class MainWindow(QMainWindow):
352
579
  for i in selected_indices:
353
580
  self.segments[i]["class_id"] = target_class_id
354
581
 
582
+ self._update_next_class_id()
355
583
  self.update_all_lists()
356
584
  self.right_panel.segment_table.clearSelection()
357
585
  self.viewer.setFocus()
@@ -378,31 +606,62 @@ class MainWindow(QMainWindow):
378
606
  for i, seg_dict in enumerate(self.segments):
379
607
  self.segment_items[i] = []
380
608
  class_id = seg_dict.get("class_id")
381
- base_color = self._get_color_for_class(class_id, saturation=220, value=220)
609
+
610
+ base_color = self._get_color_for_class(class_id)
382
611
 
383
612
  if seg_dict["type"] == "Polygon":
384
613
  poly_item = HoverablePolygonItem(QPolygonF(seg_dict["vertices"]))
385
- default_brush = QBrush(QColor(base_color.red(), base_color.green(), base_color.blue(), 70))
386
- hover_brush = QBrush(QColor(base_color.red(), base_color.green(), base_color.blue(), 170))
614
+ default_brush = QBrush(
615
+ QColor(base_color.red(), base_color.green(), base_color.blue(), 70)
616
+ )
617
+ hover_brush = QBrush(
618
+ QColor(base_color.red(), base_color.green(), base_color.blue(), 170)
619
+ )
387
620
  poly_item.set_brushes(default_brush, hover_brush)
388
621
  poly_item.setPen(QPen(Qt.GlobalColor.transparent))
389
622
  self.viewer.scene().addItem(poly_item)
390
623
  self.segment_items[i].append(poly_item)
624
+
625
+ base_color.setAlpha(150)
391
626
  vertex_color = QBrush(base_color)
627
+ point_diameter = self.point_radius * 2
392
628
  for v in seg_dict["vertices"]:
393
- dot = QGraphicsEllipseItem(v.x() - 3, v.y() - 3, 6, 6)
629
+ dot = QGraphicsEllipseItem(
630
+ v.x() - self.point_radius,
631
+ v.y() - self.point_radius,
632
+ point_diameter,
633
+ point_diameter,
634
+ )
394
635
  dot.setBrush(vertex_color)
636
+ dot.setPen(QPen(Qt.GlobalColor.transparent))
395
637
  self.viewer.scene().addItem(dot)
396
638
  self.segment_items[i].append(dot)
397
639
  if self.mode == "edit" and i in selected_indices:
640
+ handle_diameter = self.point_radius * 2
398
641
  for idx, v in enumerate(seg_dict["vertices"]):
399
- vertex_item = EditableVertexItem(self, i, idx, -4, -4, 8, 8)
642
+ vertex_item = EditableVertexItem(
643
+ self,
644
+ i,
645
+ idx,
646
+ -handle_diameter / 2,
647
+ -handle_diameter / 2,
648
+ handle_diameter,
649
+ handle_diameter,
650
+ )
400
651
  vertex_item.setPos(v)
401
652
  self.viewer.scene().addItem(vertex_item)
402
653
  self.segment_items[i].append(vertex_item)
403
654
  elif seg_dict.get("mask") is not None:
404
- pixmap = mask_to_pixmap(seg_dict["mask"], base_color.getRgb()[:3])
405
- pixmap_item = self.viewer.scene().addPixmap(pixmap)
655
+ default_pixmap = mask_to_pixmap(
656
+ seg_dict["mask"], base_color.getRgb()[:3], alpha=70
657
+ )
658
+ hover_pixmap = mask_to_pixmap(
659
+ seg_dict["mask"], base_color.getRgb()[:3], alpha=170
660
+ )
661
+
662
+ pixmap_item = HoverablePixmapItem()
663
+ pixmap_item.set_pixmaps(default_pixmap, hover_pixmap)
664
+ self.viewer.scene().addItem(pixmap_item)
406
665
  pixmap_item.setZValue(i + 1)
407
666
  self.segment_items[i].append(pixmap_item)
408
667
  self.highlight_selected_segments()
@@ -426,7 +685,11 @@ class MainWindow(QMainWindow):
426
685
  selected_indices = self.get_selected_segment_indices()
427
686
  for i in selected_indices:
428
687
  seg = self.segments[i]
429
- mask = self.rasterize_polygon(seg["vertices"]) if seg["type"] == "Polygon" else seg.get("mask")
688
+ mask = (
689
+ self.rasterize_polygon(seg["vertices"])
690
+ if seg["type"] == "Polygon"
691
+ else seg.get("mask")
692
+ )
430
693
  if mask is not None:
431
694
  pixmap = mask_to_pixmap(mask, (255, 255, 255))
432
695
  highlight_item = self.viewer.scene().addPixmap(pixmap)
@@ -434,9 +697,9 @@ class MainWindow(QMainWindow):
434
697
  self.highlight_items.append(highlight_item)
435
698
 
436
699
  def update_all_lists(self):
700
+ self.update_class_list() # Must be before filter combo
437
701
  self.update_class_filter_combo()
438
702
  self.update_segment_table()
439
- self.update_class_list()
440
703
  self.display_all_segments()
441
704
 
442
705
  def update_segment_table(self):
@@ -450,7 +713,7 @@ class MainWindow(QMainWindow):
450
713
  filter_class_id = -1
451
714
  if not show_all:
452
715
  try:
453
- filter_class_id = int(filter_text.split(" ")[1])
716
+ filter_class_id = int(filter_text.split("(ID: ")[1][:-1])
454
717
  except (ValueError, IndexError):
455
718
  pass
456
719
 
@@ -463,7 +726,7 @@ class MainWindow(QMainWindow):
463
726
 
464
727
  for row, (original_index, seg) in enumerate(display_segments):
465
728
  class_id = seg.get("class_id")
466
- color = self._get_color_for_class(class_id, saturation=180, value=200)
729
+ color = self._get_color_for_class(class_id)
467
730
 
468
731
  class_id_str = str(class_id) if class_id is not None else "N/A"
469
732
  index_item = NumericTableWidgetItem(str(original_index + 1))
@@ -478,7 +741,7 @@ class MainWindow(QMainWindow):
478
741
  table.setItem(row, 1, class_item)
479
742
  table.setItem(row, 2, type_item)
480
743
 
481
- for col in range(3):
744
+ for col in range(table.columnCount()):
482
745
  if table.item(row, col):
483
746
  table.item(row, col).setBackground(QBrush(color))
484
747
 
@@ -495,30 +758,70 @@ class MainWindow(QMainWindow):
495
758
  def update_class_list(self):
496
759
  class_table = self.right_panel.class_table
497
760
  class_table.blockSignals(True)
761
+
762
+ # Preserve existing aliases during update
763
+ current_aliases = {}
764
+ for row in range(class_table.rowCount()):
765
+ try:
766
+ alias = class_table.item(row, 0).text()
767
+ cid = int(class_table.item(row, 1).text())
768
+ current_aliases[cid] = alias
769
+ except (AttributeError, ValueError):
770
+ continue
771
+ self.class_aliases.update(current_aliases)
772
+
498
773
  class_table.clearContents()
499
774
 
500
- unique_class_ids = sorted(list({seg.get("class_id") for seg in self.segments if seg.get("class_id") is not None}))
775
+ unique_class_ids = sorted(
776
+ list(
777
+ {
778
+ seg.get("class_id")
779
+ for seg in self.segments
780
+ if seg.get("class_id") is not None
781
+ }
782
+ )
783
+ )
501
784
  class_table.setRowCount(len(unique_class_ids))
502
785
 
503
786
  for row, cid in enumerate(unique_class_ids):
504
- item = QTableWidgetItem(str(cid))
505
- item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
787
+ alias = self.class_aliases.get(cid, str(cid))
788
+ alias_item = QTableWidgetItem(alias)
789
+ id_item = QTableWidgetItem(str(cid))
506
790
 
507
- color = self._get_color_for_class(cid, saturation=180, value=200)
791
+ id_item.setFlags(id_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
508
792
 
509
- item.setBackground(QBrush(color))
510
- class_table.setItem(row, 0, item)
793
+ color = self._get_color_for_class(cid)
794
+ alias_item.setBackground(QBrush(color))
795
+ id_item.setBackground(QBrush(color))
796
+
797
+ class_table.setItem(row, 0, alias_item)
798
+ class_table.setItem(row, 1, id_item)
511
799
 
512
800
  class_table.blockSignals(False)
513
801
 
514
802
  def update_class_filter_combo(self):
515
803
  combo = self.right_panel.class_filter_combo
516
- unique_class_ids = sorted(list({seg.get("class_id") for seg in self.segments if seg.get("class_id") is not None}))
804
+ unique_class_ids = sorted(
805
+ list(
806
+ {
807
+ seg.get("class_id")
808
+ for seg in self.segments
809
+ if seg.get("class_id") is not None
810
+ }
811
+ )
812
+ )
813
+
517
814
  current_selection = combo.currentText()
518
815
  combo.blockSignals(True)
519
816
  combo.clear()
520
817
  combo.addItem("All Classes")
521
- combo.addItems([f"Class {cid}" for cid in unique_class_ids])
818
+ combo.addItems(
819
+ [
820
+ f"{self.class_aliases.get(cid, cid)} (ID: {cid})"
821
+ for cid in unique_class_ids
822
+ ]
823
+ )
824
+
522
825
  if combo.findText(current_selection) > -1:
523
826
  combo.setCurrentText(current_selection)
524
827
  else:
@@ -527,20 +830,54 @@ class MainWindow(QMainWindow):
527
830
 
528
831
  def reassign_class_ids(self):
529
832
  class_table = self.right_panel.class_table
530
- ordered_ids = [
531
- int(class_table.item(row, 0).text()) for row in range(class_table.rowCount()) if class_table.item(row, 0) is not None
532
- ]
833
+
834
+ ordered_ids = []
835
+ for row in range(class_table.rowCount()):
836
+ id_item = class_table.item(row, 1)
837
+ if id_item and id_item.text():
838
+ try:
839
+ ordered_ids.append(int(id_item.text()))
840
+ except ValueError:
841
+ continue
842
+
533
843
  id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
844
+
534
845
  for seg in self.segments:
535
846
  old_id = seg.get("class_id")
536
847
  if old_id in id_map:
537
848
  seg["class_id"] = id_map[old_id]
538
- self.next_class_id = len(ordered_ids)
849
+
850
+ new_aliases = {
851
+ id_map[old_id]: self.class_aliases.get(old_id, str(old_id))
852
+ for old_id in ordered_ids
853
+ if old_id in self.class_aliases
854
+ }
855
+ self.class_aliases = new_aliases
856
+
857
+ self._update_next_class_id()
539
858
  self.update_all_lists()
540
859
  self.viewer.setFocus()
541
860
 
861
+ def handle_alias_change(self, item):
862
+ if item.column() != 0: # Alias column
863
+ return
864
+
865
+ class_table = self.right_panel.class_table
866
+ class_table.blockSignals(True)
867
+
868
+ id_item = class_table.item(item.row(), 1)
869
+ if id_item:
870
+ try:
871
+ class_id = int(id_item.text())
872
+ self.class_aliases[class_id] = item.text()
873
+ except (ValueError, AttributeError):
874
+ pass # Ignore if ID item is not valid
875
+
876
+ class_table.blockSignals(False)
877
+ self.update_class_filter_combo() # Refresh filter to show new alias
878
+
542
879
  def handle_class_id_change(self, item):
543
- if item.column() != 1:
880
+ if item.column() != 1: # Class ID column in segment table
544
881
  return
545
882
  table = self.right_panel.segment_table
546
883
  index_item = table.item(item.row(), 0)
@@ -559,14 +896,15 @@ class MainWindow(QMainWindow):
559
896
  raise IndexError("Invalid segment index found in table.")
560
897
 
561
898
  self.segments[original_index]["class_id"] = new_class_id
562
- if new_class_id >= self.next_class_id:
563
- self.next_class_id = new_class_id + 1
899
+ self._update_next_class_id()
564
900
  self.update_all_lists()
565
901
  except (ValueError, TypeError, AttributeError, IndexError) as e:
566
902
  original_index = index_item.data(Qt.ItemDataRole.UserRole)
567
903
  if original_index is not None and original_index < len(self.segments):
568
904
  original_class_id = self.segments[original_index].get("class_id")
569
- item.setText(str(original_class_id) if original_class_id is not None else "N/A")
905
+ item.setText(
906
+ str(original_class_id) if original_class_id is not None else "N/A"
907
+ )
570
908
  finally:
571
909
  table.blockSignals(False)
572
910
  self.viewer.setFocus()
@@ -575,7 +913,11 @@ class MainWindow(QMainWindow):
575
913
  table = self.right_panel.segment_table
576
914
  selected_items = table.selectedItems()
577
915
  selected_rows = sorted(list({item.row() for item in selected_items}))
578
- return [table.item(row, 0).data(Qt.ItemDataRole.UserRole) for row in selected_rows if table.item(row, 0)]
916
+ return [
917
+ table.item(row, 0).data(Qt.ItemDataRole.UserRole)
918
+ for row in selected_rows
919
+ if table.item(row, 0)
920
+ ]
579
921
 
580
922
  def save_output_to_npz(self):
581
923
  if not self.segments or not self.current_image_path:
@@ -588,14 +930,21 @@ class MainWindow(QMainWindow):
588
930
  self.viewer._pixmap_item.pixmap().height(),
589
931
  self.viewer._pixmap_item.pixmap().width(),
590
932
  )
591
- unique_class_ids = sorted(list({seg["class_id"] for seg in self.segments if seg.get("class_id") is not None}))
592
- if not unique_class_ids:
593
- self.right_panel.status_label.setText("Save failed: No classes.")
933
+
934
+ class_table = self.right_panel.class_table
935
+ ordered_ids = [
936
+ int(class_table.item(row, 1).text())
937
+ for row in range(class_table.rowCount())
938
+ if class_table.item(row, 1) is not None
939
+ ]
940
+
941
+ if not ordered_ids:
942
+ self.right_panel.status_label.setText("Save failed: No classes defined.")
594
943
  QTimer.singleShot(3000, lambda: self.right_panel.status_label.clear())
595
944
  return
596
945
 
597
- id_map = {old_id: new_id for new_id, old_id in enumerate(unique_class_ids)}
598
- num_final_classes = len(unique_class_ids)
946
+ id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
947
+ num_final_classes = len(ordered_ids)
599
948
  final_mask_tensor = np.zeros((h, w, num_final_classes), dtype=np.uint8)
600
949
 
601
950
  for seg in self.segments:
@@ -603,12 +952,20 @@ class MainWindow(QMainWindow):
603
952
  if class_id not in id_map:
604
953
  continue
605
954
  new_channel_idx = id_map[class_id]
606
- mask = self.rasterize_polygon(seg["vertices"]) if seg["type"] == "Polygon" else seg.get("mask")
955
+ mask = (
956
+ self.rasterize_polygon(seg["vertices"])
957
+ if seg["type"] == "Polygon"
958
+ else seg.get("mask")
959
+ )
607
960
  if mask is not None:
608
- final_mask_tensor[:, :, new_channel_idx] = np.logical_or(final_mask_tensor[:, :, new_channel_idx], mask)
961
+ final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
962
+ final_mask_tensor[:, :, new_channel_idx], mask
963
+ )
609
964
 
610
965
  np.savez_compressed(output_path, mask=final_mask_tensor.astype(np.uint8))
611
- self.file_model.setRootPath(self.file_model.rootPath())
966
+
967
+ self.file_model.set_highlighted_path(output_path)
968
+ QTimer.singleShot(1500, lambda: self.file_model.set_highlighted_path(None))
612
969
 
613
970
  self.right_panel.status_label.setText("Saved!")
614
971
  self.generate_yolo_annotations(npz_file_path=output_path)
@@ -629,7 +986,9 @@ class MainWindow(QMainWindow):
629
986
 
630
987
  for channel in range(num_channels):
631
988
  single_channel_image = img[:, :, channel]
632
- contours, _ = cv2.findContours(single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
989
+ contours, _ = cv2.findContours(
990
+ single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
991
+ )
633
992
 
634
993
  class_id = channel # Use the channel index as the class ID
635
994
 
@@ -651,7 +1010,11 @@ class MainWindow(QMainWindow):
651
1010
  file.write(annotation + "\n")
652
1011
 
653
1012
  def save_current_segment(self):
654
- if self.mode != "sam_points" or not hasattr(self, "preview_mask_item") or not self.preview_mask_item:
1013
+ if (
1014
+ self.mode != "sam_points"
1015
+ or not hasattr(self, "preview_mask_item")
1016
+ or not self.preview_mask_item
1017
+ ):
655
1018
  return
656
1019
  mask = self.sam_model.predict(self.positive_points, self.negative_points)
657
1020
  if mask is not None:
@@ -663,7 +1026,7 @@ class MainWindow(QMainWindow):
663
1026
  "class_id": self.next_class_id,
664
1027
  }
665
1028
  )
666
- self.next_class_id += 1
1029
+ self._update_next_class_id()
667
1030
  self.clear_all_points()
668
1031
  self.update_all_lists()
669
1032
 
@@ -673,6 +1036,7 @@ class MainWindow(QMainWindow):
673
1036
  return
674
1037
  for i in sorted(selected_indices, reverse=True):
675
1038
  del self.segments[i]
1039
+ self._update_next_class_id()
676
1040
  self.update_all_lists()
677
1041
  self.viewer.setFocus()
678
1042
 
@@ -698,16 +1062,27 @@ class MainWindow(QMainWindow):
698
1062
  "class_id": i,
699
1063
  }
700
1064
  )
701
- self.next_class_id = num_classes
1065
+ self._update_next_class_id()
702
1066
  self.update_all_lists()
703
1067
 
704
1068
  def add_point(self, pos, positive):
705
1069
  point_list = self.positive_points if positive else self.negative_points
706
1070
  point_list.append([int(pos.x()), int(pos.y())])
707
- color = Qt.GlobalColor.green if positive else Qt.GlobalColor.red
708
- point_item = QGraphicsEllipseItem(pos.x() - 4, pos.y() - 4, 8, 8)
709
- point_item.setBrush(QBrush(color))
710
- point_item.setPen(QPen(Qt.GlobalColor.white))
1071
+
1072
+ point_color = (
1073
+ QColor(Qt.GlobalColor.green) if positive else QColor(Qt.GlobalColor.red)
1074
+ )
1075
+ point_color.setAlpha(150)
1076
+
1077
+ point_diameter = self.point_radius * 2
1078
+ point_item = QGraphicsEllipseItem(
1079
+ pos.x() - self.point_radius,
1080
+ pos.y() - self.point_radius,
1081
+ point_diameter,
1082
+ point_diameter,
1083
+ )
1084
+ point_item.setBrush(QBrush(point_color))
1085
+ point_item.setPen(QPen(Qt.GlobalColor.transparent))
711
1086
  self.viewer.scene().addItem(point_item)
712
1087
  self.point_items.append(point_item)
713
1088
 
@@ -740,26 +1115,45 @@ class MainWindow(QMainWindow):
740
1115
  self.preview_mask_item = None
741
1116
 
742
1117
  def handle_polygon_click(self, pos):
743
- if self.polygon_points and (((pos.x() - self.polygon_points[0].x()) ** 2 + (pos.y() - self.polygon_points[0].y()) ** 2) < 25):
1118
+ if self.polygon_points and (
1119
+ (
1120
+ (pos.x() - self.polygon_points[0].x()) ** 2
1121
+ + (pos.y() - self.polygon_points[0].y()) ** 2
1122
+ )
1123
+ < 4 # pixel distance threshold squared
1124
+ ):
744
1125
  if len(self.polygon_points) > 2:
745
1126
  self.finalize_polygon()
746
1127
  return
747
1128
  self.polygon_points.append(pos)
748
- dot = QGraphicsEllipseItem(pos.x() - 2, pos.y() - 2, 4, 4)
749
- dot.setBrush(QBrush(Qt.GlobalColor.blue))
750
- dot.setPen(QPen(Qt.GlobalColor.cyan))
1129
+ point_diameter = self.point_radius * 2
1130
+
1131
+ point_color = QColor(Qt.GlobalColor.blue)
1132
+ point_color.setAlpha(150)
1133
+
1134
+ dot = QGraphicsEllipseItem(
1135
+ pos.x() - self.point_radius,
1136
+ pos.y() - self.point_radius,
1137
+ point_diameter,
1138
+ point_diameter,
1139
+ )
1140
+ dot.setBrush(QBrush(point_color))
1141
+ dot.setPen(QPen(Qt.GlobalColor.transparent))
751
1142
  self.viewer.scene().addItem(dot)
752
1143
  self.polygon_preview_items.append(dot)
753
1144
  self.draw_polygon_preview()
754
1145
 
755
1146
  def draw_polygon_preview(self):
756
- if self.rubber_band_line:
757
- self.viewer.scene().removeItem(self.rubber_band_line)
758
- self.rubber_band_line = None
1147
+ # Clean up old preview lines/polygons
759
1148
  for item in self.polygon_preview_items:
760
1149
  if not isinstance(item, QGraphicsEllipseItem):
761
- self.viewer.scene().removeItem(item)
762
- self.polygon_preview_items = [item for item in self.polygon_preview_items if isinstance(item, QGraphicsEllipseItem)]
1150
+ if item.scene():
1151
+ self.viewer.scene().removeItem(item)
1152
+ self.polygon_preview_items = [
1153
+ item
1154
+ for item in self.polygon_preview_items
1155
+ if isinstance(item, QGraphicsEllipseItem)
1156
+ ]
763
1157
 
764
1158
  if len(self.polygon_points) > 2:
765
1159
  preview_poly = QGraphicsPolygonItem(QPolygonF(self.polygon_points))
@@ -769,6 +1163,8 @@ class MainWindow(QMainWindow):
769
1163
  self.polygon_preview_items.append(preview_poly)
770
1164
 
771
1165
  if len(self.polygon_points) > 1:
1166
+ line_color = QColor(Qt.GlobalColor.cyan)
1167
+ line_color.setAlpha(150)
772
1168
  for i in range(len(self.polygon_points) - 1):
773
1169
  line = QGraphicsLineItem(
774
1170
  self.polygon_points[i].x(),
@@ -776,7 +1172,7 @@ class MainWindow(QMainWindow):
776
1172
  self.polygon_points[i + 1].x(),
777
1173
  self.polygon_points[i + 1].y(),
778
1174
  )
779
- line.setPen(QPen(Qt.GlobalColor.cyan, 2))
1175
+ line.setPen(QPen(line_color, self.line_thickness))
780
1176
  self.viewer.scene().addItem(line)
781
1177
  self.polygon_preview_items.append(line)
782
1178