lazylabel-gui 1.0.9__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. lazylabel/__init__.py +9 -0
  2. lazylabel/config/__init__.py +7 -0
  3. lazylabel/config/hotkeys.py +169 -0
  4. lazylabel/config/paths.py +41 -0
  5. lazylabel/config/settings.py +66 -0
  6. lazylabel/core/__init__.py +7 -0
  7. lazylabel/core/file_manager.py +106 -0
  8. lazylabel/core/model_manager.py +94 -0
  9. lazylabel/core/segment_manager.py +140 -0
  10. lazylabel/main.py +10 -1266
  11. lazylabel/models/__init__.py +5 -0
  12. lazylabel/models/sam_model.py +154 -0
  13. lazylabel/ui/__init__.py +8 -0
  14. lazylabel/ui/control_panel.py +220 -0
  15. lazylabel/{editable_vertex.py → ui/editable_vertex.py} +25 -3
  16. lazylabel/ui/hotkey_dialog.py +384 -0
  17. lazylabel/{hoverable_polygon_item.py → ui/hoverable_polygon_item.py} +17 -1
  18. lazylabel/ui/main_window.py +1264 -0
  19. lazylabel/ui/right_panel.py +239 -0
  20. lazylabel/ui/widgets/__init__.py +7 -0
  21. lazylabel/ui/widgets/adjustments_widget.py +107 -0
  22. lazylabel/ui/widgets/model_selection_widget.py +94 -0
  23. lazylabel/ui/widgets/settings_widget.py +106 -0
  24. lazylabel/utils/__init__.py +6 -0
  25. lazylabel/{custom_file_system_model.py → utils/custom_file_system_model.py} +9 -3
  26. {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.0.dist-info}/METADATA +61 -11
  27. lazylabel_gui-1.1.0.dist-info/RECORD +36 -0
  28. lazylabel/controls.py +0 -265
  29. lazylabel/sam_model.py +0 -70
  30. lazylabel_gui-1.0.9.dist-info/RECORD +0 -17
  31. /lazylabel/{hoverable_pixelmap_item.py → ui/hoverable_pixelmap_item.py} +0 -0
  32. /lazylabel/{numeric_table_widget_item.py → ui/numeric_table_widget_item.py} +0 -0
  33. /lazylabel/{photo_viewer.py → ui/photo_viewer.py} +0 -0
  34. /lazylabel/{reorderable_class_table.py → ui/reorderable_class_table.py} +0 -0
  35. /lazylabel/{utils.py → utils/utils.py} +0 -0
  36. {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.0.dist-info}/WHEEL +0 -0
  37. {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.0.dist-info}/entry_points.txt +0 -0
  38. {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.0.dist-info}/licenses/LICENSE +0 -0
  39. {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.0.dist-info}/top_level.txt +0 -0
lazylabel/main.py CHANGED
@@ -1,1278 +1,22 @@
1
+ """Main entry point for LazyLabel application."""
2
+
1
3
  import sys
2
- import os
3
- import numpy as np
4
4
  import qdarktheme
5
- import cv2
6
- import json
7
- from PyQt6.QtWidgets import (
8
- QApplication,
9
- QMainWindow,
10
- QWidget,
11
- QHBoxLayout,
12
- QFileDialog,
13
- QGraphicsItem,
14
- QGraphicsEllipseItem,
15
- QGraphicsLineItem,
16
- QTableWidgetItem,
17
- QGraphicsPolygonItem,
18
- QTableWidgetSelectionRange,
19
- QSpacerItem,
20
- QHeaderView,
21
- )
22
- from PyQt6.QtGui import (
23
- QPixmap,
24
- QColor,
25
- QPen,
26
- QBrush,
27
- QPolygonF,
28
- QIcon,
29
- QCursor,
30
- QKeySequence,
31
- QShortcut,
32
- )
33
- from PyQt6.QtCore import Qt, QPointF, QTimer, QModelIndex
34
-
35
- from .photo_viewer import PhotoViewer
36
- from .sam_model import SamModel
37
- from .utils import mask_to_pixmap
38
- from .controls import ControlPanel, RightPanel
39
- from .custom_file_system_model import CustomFileSystemModel
40
- from .editable_vertex import EditableVertexItem
41
- from .hoverable_polygon_item import HoverablePolygonItem
42
- from .hoverable_pixelmap_item import HoverablePixmapItem
43
- from .numeric_table_widget_item import NumericTableWidgetItem
44
-
45
-
46
- class MainWindow(QMainWindow):
47
- def __init__(self, sam_model):
48
- super().__init__()
49
- self.setWindowTitle("LazyLabel by DNC")
50
-
51
- icon_path = os.path.join(
52
- os.path.dirname(__file__), "demo_pictures", "logo2.png"
53
- )
54
- if os.path.exists(icon_path):
55
- self.setWindowIcon(QIcon(icon_path))
56
-
57
- self.setGeometry(50, 50, 1600, 900)
58
-
59
- self.sam_model = sam_model
60
- self.mode = "sam_points"
61
- self.previous_mode = "sam_points"
62
- self.current_image_path = None
63
- self.current_file_index = QModelIndex()
64
-
65
- self.next_class_id = 0
66
- self.class_aliases = {}
67
-
68
- self._original_point_radius = 0.3
69
- self._original_line_thickness = 0.5
70
- self.point_radius = self._original_point_radius
71
- self.line_thickness = self._original_line_thickness
72
-
73
- self.pan_multiplier = 1.0
74
- self.polygon_join_threshold = 2
75
-
76
- self.point_items, self.positive_points, self.negative_points = [], [], []
77
- self.polygon_points, self.polygon_preview_items = [], []
78
- self.rubber_band_line = None
79
-
80
- self.segments, self.segment_items, self.highlight_items = [], {}, []
81
- self.is_dragging_polygon, self.drag_start_pos, self.drag_initial_vertices = (
82
- False,
83
- None,
84
- {},
85
- )
86
-
87
- self.control_panel = ControlPanel()
88
- self.right_panel = RightPanel()
89
- self.viewer = PhotoViewer(self)
90
- self.viewer.setMouseTracking(True)
91
- self.file_model = CustomFileSystemModel()
92
- self.right_panel.file_tree.setModel(self.file_model)
93
- self.right_panel.file_tree.setColumnWidth(0, 200)
94
- file_tree = self.right_panel.file_tree
95
- header = file_tree.header()
96
- header.setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
97
- header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
98
-
99
- main_layout = QHBoxLayout()
100
- main_layout.addWidget(self.control_panel)
101
- main_layout.addWidget(self.viewer, 1)
102
- main_layout.addWidget(self.right_panel)
103
- central_widget = QWidget()
104
- central_widget.setLayout(main_layout)
105
- self.setCentralWidget(central_widget)
106
-
107
- self.control_panel.device_label.setText(
108
- f"Device: {str(self.sam_model.device).upper()}"
109
- )
110
- self.setup_connections()
111
- self.set_sam_mode()
112
- self.set_annotation_size(10)
113
-
114
- def setup_connections(self):
115
- self._original_mouse_press = self.viewer.scene().mousePressEvent
116
- self._original_mouse_move = self.viewer.scene().mouseMoveEvent
117
- self._original_mouse_release = self.viewer.scene().mouseReleaseEvent
118
-
119
- self.viewer.scene().mousePressEvent = self.scene_mouse_press
120
- self.viewer.scene().mouseMoveEvent = self.scene_mouse_move
121
- self.viewer.scene().mouseReleaseEvent = self.scene_mouse_release
122
-
123
- self.right_panel.btn_open_folder.clicked.connect(self.open_folder_dialog)
124
- self.right_panel.file_tree.doubleClicked.connect(self.load_selected_image)
125
- self.right_panel.btn_merge_selection.clicked.connect(
126
- self.assign_selected_to_class
127
- )
128
- self.right_panel.btn_delete_selection.clicked.connect(
129
- self.delete_selected_segments
130
- )
131
- self.right_panel.segment_table.itemSelectionChanged.connect(
132
- self.highlight_selected_segments
133
- )
134
- self.right_panel.class_table.itemChanged.connect(self.handle_alias_change)
135
- self.right_panel.btn_reassign_classes.clicked.connect(self.reassign_class_ids)
136
- self.right_panel.class_filter_combo.currentIndexChanged.connect(
137
- self.update_segment_table
138
- )
139
-
140
- self.control_panel.btn_sam_mode.clicked.connect(self.set_sam_mode)
141
- self.control_panel.btn_polygon_mode.clicked.connect(self.set_polygon_mode)
142
- self.control_panel.btn_selection_mode.clicked.connect(
143
- self.toggle_selection_mode
144
- )
145
- self.control_panel.btn_clear_points.clicked.connect(self.clear_all_points)
146
- self.control_panel.btn_fit_view.clicked.connect(self.viewer.fitInView)
147
-
148
- self.control_panel.size_slider.valueChanged.connect(self.set_annotation_size)
149
- self.control_panel.pan_slider.valueChanged.connect(self.set_pan_multiplier)
150
- self.control_panel.join_slider.valueChanged.connect(
151
- self.set_polygon_join_threshold
152
- )
153
-
154
- self.control_panel.chk_save_npz.stateChanged.connect(
155
- self.handle_save_checkbox_change
156
- )
157
- self.control_panel.chk_save_txt.stateChanged.connect(
158
- self.handle_save_checkbox_change
159
- )
160
-
161
- self.control_panel.btn_toggle_visibility.clicked.connect(self.toggle_left_panel)
162
- self.right_panel.btn_toggle_visibility.clicked.connect(self.toggle_right_panel)
163
-
164
- QShortcut(QKeySequence(Qt.Key.Key_Right), self, self.load_next_image)
165
- QShortcut(QKeySequence(Qt.Key.Key_Left), self, self.load_previous_image)
166
- QShortcut(QKeySequence(Qt.Key.Key_1), self, self.set_sam_mode)
167
- QShortcut(QKeySequence(Qt.Key.Key_2), self, self.set_polygon_mode)
168
- QShortcut(QKeySequence(Qt.Key.Key_E), self, self.toggle_selection_mode)
169
- QShortcut(QKeySequence(Qt.Key.Key_Q), self, self.toggle_pan_mode)
170
- QShortcut(QKeySequence(Qt.Key.Key_R), self, self.toggle_edit_mode)
171
- QShortcut(QKeySequence(Qt.Key.Key_C), self, self.clear_all_points)
172
- QShortcut(QKeySequence(Qt.Key.Key_Escape), self, self.handle_escape_press)
173
- QShortcut(QKeySequence(Qt.Key.Key_V), self, self.delete_selected_segments)
174
- QShortcut(
175
- QKeySequence(Qt.Key.Key_Backspace), self, self.delete_selected_segments
176
- )
177
- QShortcut(QKeySequence(Qt.Key.Key_M), self, self.handle_merge_press)
178
- QShortcut(QKeySequence("Ctrl+Z"), self, self.undo_last_action)
179
- QShortcut(
180
- QKeySequence("Ctrl+A"), self, self.right_panel.segment_table.selectAll
181
- )
182
- QShortcut(QKeySequence(Qt.Key.Key_Space), self, self.handle_space_press)
183
- QShortcut(QKeySequence(Qt.Key.Key_Return), self, self.handle_enter_press)
184
- QShortcut(QKeySequence(Qt.Key.Key_Enter), self, self.handle_enter_press)
185
- QShortcut(QKeySequence(Qt.Key.Key_Period), self, self.viewer.fitInView)
186
-
187
- def toggle_left_panel(self):
188
- is_visible = self.control_panel.main_controls_widget.isVisible()
189
- self.control_panel.main_controls_widget.setVisible(not is_visible)
190
- if is_visible:
191
- self.control_panel.btn_toggle_visibility.setText("> Show")
192
- self.control_panel.setFixedWidth(
193
- self.control_panel.btn_toggle_visibility.sizeHint().width() + 20
194
- )
195
- else:
196
- self.control_panel.btn_toggle_visibility.setText("< Hide")
197
- self.control_panel.setFixedWidth(250)
198
-
199
- def toggle_right_panel(self):
200
- is_visible = self.right_panel.main_controls_widget.isVisible()
201
- self.right_panel.main_controls_widget.setVisible(not is_visible)
202
- layout = self.right_panel.v_layout
203
-
204
- if is_visible: # Content is now hidden
205
- layout.addStretch(1)
206
- self.right_panel.btn_toggle_visibility.setText("< Show")
207
- self.right_panel.setFixedWidth(
208
- self.right_panel.btn_toggle_visibility.sizeHint().width() + 20
209
- )
210
- else: # Content is now visible
211
- # Remove the stretch so the content can expand
212
- for i in range(layout.count()):
213
- item = layout.itemAt(i)
214
- if isinstance(item, QSpacerItem):
215
- layout.removeItem(item)
216
- break
217
- self.right_panel.btn_toggle_visibility.setText("Hide >")
218
- self.right_panel.setFixedWidth(350)
219
-
220
- def handle_save_checkbox_change(self):
221
- is_npz_checked = self.control_panel.chk_save_npz.isChecked()
222
- is_txt_checked = self.control_panel.chk_save_txt.isChecked()
223
-
224
- if not is_npz_checked and not is_txt_checked:
225
- sender = self.sender()
226
- if sender == self.control_panel.chk_save_npz:
227
- self.control_panel.chk_save_txt.setChecked(True)
228
- else:
229
- self.control_panel.chk_save_npz.setChecked(True)
230
-
231
- def set_annotation_size(self, value):
232
- multiplier = value / 10.0
233
- self.point_radius = self._original_point_radius * multiplier
234
- self.line_thickness = self._original_line_thickness * multiplier
235
-
236
- self.control_panel.size_label.setText(f"Annotation Size: {multiplier:.1f}x")
237
-
238
- if self.control_panel.size_slider.value() != value:
239
- self.control_panel.size_slider.setValue(value)
240
-
241
- self.display_all_segments()
242
- self.clear_all_points()
243
-
244
- def set_pan_multiplier(self, value):
245
- self.pan_multiplier = value / 10.0
246
- self.control_panel.pan_label.setText(f"Pan Speed: {self.pan_multiplier:.1f}x")
247
-
248
- def set_polygon_join_threshold(self, value):
249
- self.polygon_join_threshold = value
250
- self.control_panel.join_label.setText(f"Polygon Join Distance: {value}px")
251
-
252
- def handle_escape_press(self):
253
- self.right_panel.segment_table.clearSelection()
254
- self.right_panel.class_table.clearSelection()
255
- self.clear_all_points()
256
- self.viewer.setFocus()
257
-
258
- def handle_space_press(self):
259
- if self.mode == "polygon" and self.polygon_points:
260
- self.finalize_polygon()
261
- else:
262
- self.save_current_segment()
263
-
264
- def handle_enter_press(self):
265
- if self.mode == "polygon" and self.polygon_points:
266
- self.finalize_polygon()
267
- else:
268
- self.save_output_to_npz()
269
-
270
- def handle_merge_press(self):
271
- self.assign_selected_to_class()
272
- self.right_panel.segment_table.clearSelection()
273
-
274
- def show_notification(self, message, duration=3000):
275
- self.control_panel.notification_label.setText(message)
276
- QTimer.singleShot(
277
- duration, lambda: self.control_panel.notification_label.clear()
278
- )
279
-
280
- def _get_color_for_class(self, class_id):
281
- if class_id is None:
282
- return QColor.fromHsv(0, 0, 128)
283
- hue = int((class_id * 222.4922359) % 360)
284
- color = QColor.fromHsv(hue, 220, 220)
285
- if not color.isValid():
286
- return QColor(Qt.GlobalColor.white)
287
- return color
288
-
289
- def set_mode(self, mode_name, is_toggle=False):
290
- if self.mode == "selection" and mode_name not in ["selection", "edit"]:
291
- self.right_panel.segment_table.clearSelection()
292
- if self.mode == "edit" and mode_name != "edit":
293
- self.display_all_segments()
294
-
295
- if not is_toggle and self.mode not in ["selection", "edit"]:
296
- self.previous_mode = self.mode
297
-
298
- self.mode = mode_name
299
- self.control_panel.mode_label.setText(
300
- f"Mode: {mode_name.replace('_', ' ').title()}"
301
- )
302
- self.clear_all_points()
303
-
304
- cursor_map = {
305
- "sam_points": Qt.CursorShape.CrossCursor,
306
- "polygon": Qt.CursorShape.CrossCursor,
307
- "selection": Qt.CursorShape.ArrowCursor,
308
- "edit": Qt.CursorShape.SizeAllCursor,
309
- "pan": Qt.CursorShape.OpenHandCursor,
310
- }
311
- self.viewer.set_cursor(cursor_map.get(self.mode, Qt.CursorShape.ArrowCursor))
312
-
313
- self.viewer.setDragMode(
314
- self.viewer.DragMode.ScrollHandDrag
315
- if self.mode == "pan"
316
- else self.viewer.DragMode.NoDrag
317
- )
318
-
319
- def set_sam_mode(self):
320
- self.set_mode("sam_points")
321
-
322
- def set_polygon_mode(self):
323
- self.set_mode("polygon")
324
-
325
- def toggle_mode(self, new_mode):
326
- if self.mode == new_mode:
327
- self.set_mode(self.previous_mode, is_toggle=True)
328
- else:
329
- if self.mode not in ["selection", "edit"]:
330
- self.previous_mode = self.mode
331
- self.set_mode(new_mode, is_toggle=True)
332
-
333
- def toggle_pan_mode(self):
334
- self.toggle_mode("pan")
335
-
336
- def toggle_selection_mode(self):
337
- self.toggle_mode("selection")
338
-
339
- def toggle_edit_mode(self):
340
- selected_indices = self.get_selected_segment_indices()
341
-
342
- if self.mode == "edit":
343
- self.set_mode("selection", is_toggle=True)
344
- return
345
-
346
- if not selected_indices:
347
- self.show_notification("Select a polygon to edit.")
348
- return
349
-
350
- can_edit = any(
351
- self.segments[i].get("type") == "Polygon" for i in selected_indices
352
- )
353
-
354
- if not can_edit:
355
- self.show_notification("Only polygon segments can be edited.")
356
- return
357
-
358
- self.set_mode("edit", is_toggle=True)
359
- self.display_all_segments()
360
-
361
- def open_folder_dialog(self):
362
- folder_path = QFileDialog.getExistingDirectory(self, "Select Image Folder")
363
- if folder_path:
364
- self.right_panel.file_tree.setRootIndex(
365
- self.file_model.setRootPath(folder_path)
366
- )
367
- self.viewer.setFocus()
368
-
369
- def load_selected_image(self, index):
370
- if not index.isValid() or not self.file_model.isDir(index.parent()):
371
- return
372
-
373
- self.current_file_index = index
374
- path = self.file_model.filePath(index)
375
-
376
- if os.path.isfile(path) and path.lower().endswith(
377
- (".png", ".jpg", ".jpeg", ".tiff", ".tif")
378
- ):
379
- self.current_image_path = path
380
- pixmap = QPixmap(self.current_image_path)
381
- if not pixmap.isNull():
382
- self.reset_state()
383
- self.viewer.set_photo(pixmap)
384
- self.sam_model.set_image(self.current_image_path)
385
- self.load_class_aliases()
386
- self.load_existing_mask()
387
- self.right_panel.file_tree.setCurrentIndex(index)
388
- self.viewer.setFocus()
389
-
390
- def load_next_image(self):
391
- if not self.current_file_index.isValid():
392
- return
393
-
394
- if self.control_panel.chk_auto_save.isChecked():
395
- self.save_output_to_npz()
396
-
397
- row = self.current_file_index.row()
398
- parent = self.current_file_index.parent()
399
- if row + 1 < self.file_model.rowCount(parent):
400
- next_index = self.file_model.index(row + 1, 0, parent)
401
- self.load_selected_image(next_index)
402
-
403
- def load_previous_image(self):
404
- if not self.current_file_index.isValid():
405
- return
406
-
407
- if self.control_panel.chk_auto_save.isChecked():
408
- self.save_output_to_npz()
409
-
410
- row = self.current_file_index.row()
411
- parent = self.current_file_index.parent()
412
- if row > 0:
413
- prev_index = self.file_model.index(row - 1, 0, parent)
414
- self.load_selected_image(prev_index)
415
-
416
- def reset_state(self):
417
- self.clear_all_points()
418
- self.segments.clear()
419
- self.class_aliases.clear()
420
- self.next_class_id = 0
421
- self.update_all_lists()
422
- items_to_remove = [
423
- item
424
- for item in self.viewer.scene().items()
425
- if item is not self.viewer._pixmap_item
426
- ]
427
- for item in items_to_remove:
428
- self.viewer.scene().removeItem(item)
429
- self.segment_items.clear()
430
- self.highlight_items.clear()
431
-
432
- def keyPressEvent(self, event):
433
- key, mods = event.key(), event.modifiers()
434
-
435
- if event.isAutoRepeat() and key not in {
436
- Qt.Key.Key_W,
437
- Qt.Key.Key_A,
438
- Qt.Key.Key_S,
439
- Qt.Key.Key_D,
440
- }:
441
- return
442
-
443
- shift_multiplier = 5.0 if mods & Qt.KeyboardModifier.ShiftModifier else 1.0
444
-
445
- if key == Qt.Key.Key_W:
446
- amount = int(
447
- self.viewer.height() * 0.1 * self.pan_multiplier * shift_multiplier
448
- )
449
- self.viewer.verticalScrollBar().setValue(
450
- self.viewer.verticalScrollBar().value() - amount
451
- )
452
- elif key == Qt.Key.Key_S:
453
- amount = int(
454
- self.viewer.height() * 0.1 * self.pan_multiplier * shift_multiplier
455
- )
456
- self.viewer.verticalScrollBar().setValue(
457
- self.viewer.verticalScrollBar().value() + amount
458
- )
459
- elif key == Qt.Key.Key_A:
460
- amount = int(
461
- self.viewer.width() * 0.1 * self.pan_multiplier * shift_multiplier
462
- )
463
- self.viewer.horizontalScrollBar().setValue(
464
- self.viewer.horizontalScrollBar().value() - amount
465
- )
466
- elif key == Qt.Key.Key_D:
467
- amount = int(
468
- self.viewer.width() * 0.1 * self.pan_multiplier * shift_multiplier
469
- )
470
- self.viewer.horizontalScrollBar().setValue(
471
- self.viewer.horizontalScrollBar().value() + amount
472
- )
473
- elif (
474
- key == Qt.Key.Key_Equal or key == Qt.Key.Key_Plus
475
- ) and mods == Qt.KeyboardModifier.ControlModifier:
476
- current_val = self.control_panel.size_slider.value()
477
- self.control_panel.size_slider.setValue(current_val + 1)
478
- elif key == Qt.Key.Key_Minus and mods == Qt.KeyboardModifier.ControlModifier:
479
- current_val = self.control_panel.size_slider.value()
480
- self.control_panel.size_slider.setValue(current_val - 1)
481
- else:
482
- super().keyPressEvent(event)
483
-
484
- def scene_mouse_press(self, event):
485
- self._original_mouse_press(event)
486
- if event.isAccepted():
487
- return
488
-
489
- if self.mode == "pan":
490
- self.viewer.set_cursor(Qt.CursorShape.ClosedHandCursor)
491
-
492
- pos = event.scenePos()
493
- if (
494
- self.viewer._pixmap_item.pixmap().isNull()
495
- or not self.viewer._pixmap_item.pixmap().rect().contains(pos.toPoint())
496
- ):
497
- return
498
- if self.mode == "sam_points":
499
- if event.button() == Qt.MouseButton.LeftButton:
500
- self.add_point(pos, positive=True)
501
- self.update_segmentation()
502
- elif event.button() == Qt.MouseButton.RightButton:
503
- self.add_point(pos, positive=False)
504
- self.update_segmentation()
505
- elif self.mode == "polygon":
506
- if event.button() == Qt.MouseButton.LeftButton:
507
- self.handle_polygon_click(pos)
508
- elif self.mode == "selection":
509
- if event.button() == Qt.MouseButton.LeftButton:
510
- self.handle_segment_selection_click(pos)
511
- elif self.mode == "edit":
512
- self.drag_start_pos = pos
513
- self.is_dragging_polygon = True
514
- selected_indices = self.get_selected_segment_indices()
515
- self.drag_initial_vertices = {
516
- i: list(self.segments[i]["vertices"])
517
- for i in selected_indices
518
- if self.segments[i].get("type") == "Polygon"
519
- }
520
-
521
- def scene_mouse_move(self, event):
522
- pos = event.scenePos()
523
- if self.mode == "edit" and self.is_dragging_polygon:
524
- delta = pos - self.drag_start_pos
525
- for i, initial_verts in self.drag_initial_vertices.items():
526
- self.segments[i]["vertices"] = [
527
- QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts
528
- ]
529
- self.update_polygon_visuals(i)
530
- elif self.mode == "polygon" and self.polygon_points:
531
- if self.rubber_band_line is None:
532
- self.rubber_band_line = QGraphicsLineItem()
533
- line_color = QColor(Qt.GlobalColor.white)
534
- line_color.setAlpha(150)
535
- self.rubber_band_line.setPen(
536
- QPen(line_color, self.line_thickness, Qt.PenStyle.DotLine)
537
- )
538
- self.viewer.scene().addItem(self.rubber_band_line)
539
- self.rubber_band_line.setLine(
540
- self.polygon_points[-1].x(),
541
- self.polygon_points[-1].y(),
542
- pos.x(),
543
- pos.y(),
544
- )
545
- self.rubber_band_line.show()
546
- else:
547
- self._original_mouse_move(event)
548
-
549
- def scene_mouse_release(self, event):
550
- if self.mode == "pan":
551
- self.viewer.set_cursor(Qt.CursorShape.OpenHandCursor)
552
- if self.mode == "edit" and self.is_dragging_polygon:
553
- self.is_dragging_polygon = False
554
- self.drag_initial_vertices.clear()
555
- self._original_mouse_release(event)
556
-
557
- def undo_last_action(self):
558
- if self.mode == "polygon" and self.polygon_points:
559
- self.polygon_points.pop()
560
- for item in self.polygon_preview_items:
561
- if item.scene():
562
- self.viewer.scene().removeItem(item)
563
- self.polygon_preview_items.clear()
564
- for point in self.polygon_points:
565
- point_diameter = self.point_radius * 2
566
- point_color = QColor(Qt.GlobalColor.blue)
567
- point_color.setAlpha(150)
568
- dot = QGraphicsEllipseItem(
569
- point.x() - self.point_radius,
570
- point.y() - self.point_radius,
571
- point_diameter,
572
- point_diameter,
573
- )
574
- dot.setBrush(QBrush(point_color))
575
- dot.setPen(QPen(Qt.GlobalColor.transparent))
576
- self.viewer.scene().addItem(dot)
577
- self.polygon_preview_items.append(dot)
578
- self.draw_polygon_preview()
579
- elif self.mode == "sam_points" and self.point_items:
580
- item_to_remove = self.point_items.pop()
581
- point_pos = item_to_remove.rect().topLeft() + QPointF(
582
- self.point_radius, self.point_radius
583
- )
584
- point_coords = [int(point_pos.x()), int(point_pos.y())]
585
- if point_coords in self.positive_points:
586
- self.positive_points.remove(point_coords)
587
- elif point_coords in self.negative_points:
588
- self.negative_points.remove(point_coords)
589
- self.viewer.scene().removeItem(item_to_remove)
590
- self.update_segmentation()
591
-
592
- def _update_next_class_id(self):
593
- all_ids = {
594
- seg.get("class_id")
595
- for seg in self.segments
596
- if seg.get("class_id") is not None
597
- }
598
- if not all_ids:
599
- self.next_class_id = 0
600
- else:
601
- self.next_class_id = max(all_ids) + 1
602
-
603
- def finalize_polygon(self):
604
- if len(self.polygon_points) < 3:
605
- return
606
- if self.rubber_band_line:
607
- self.viewer.scene().removeItem(self.rubber_band_line)
608
- self.rubber_band_line = None
609
- self.segments.append(
610
- {
611
- "vertices": list(self.polygon_points),
612
- "type": "Polygon",
613
- "mask": None,
614
- "class_id": self.next_class_id,
615
- }
616
- )
617
- self._update_next_class_id()
618
- self.polygon_points.clear()
619
- for item in self.polygon_preview_items:
620
- self.viewer.scene().removeItem(item)
621
- self.polygon_preview_items.clear()
622
- self.update_all_lists()
623
-
624
- def handle_segment_selection_click(self, pos):
625
- x, y = int(pos.x()), int(pos.y())
626
- for i in range(len(self.segments) - 1, -1, -1):
627
- seg = self.segments[i]
628
- mask = (
629
- self.rasterize_polygon(seg["vertices"])
630
- if seg["type"] == "Polygon"
631
- else seg.get("mask")
632
- )
633
- if (
634
- mask is not None
635
- and y < mask.shape[0]
636
- and x < mask.shape[1]
637
- and mask[y, x]
638
- ):
639
- for j in range(self.right_panel.segment_table.rowCount()):
640
- item = self.right_panel.segment_table.item(j, 0)
641
- if item and item.data(Qt.ItemDataRole.UserRole) == i:
642
- table = self.right_panel.segment_table
643
- is_selected = table.item(j, 0).isSelected()
644
- range_to_select = QTableWidgetSelectionRange(
645
- j, 0, j, table.columnCount() - 1
646
- )
647
- table.setRangeSelected(range_to_select, not is_selected)
648
- return
649
- self.viewer.setFocus()
650
-
651
- def assign_selected_to_class(self):
652
- selected_indices = self.get_selected_segment_indices()
653
- if not selected_indices:
654
- return
655
-
656
- existing_class_ids = [
657
- self.segments[i]["class_id"]
658
- for i in selected_indices
659
- if self.segments[i].get("class_id") is not None
660
- ]
661
-
662
- if existing_class_ids:
663
- target_class_id = min(existing_class_ids)
664
- else:
665
- target_class_id = self.next_class_id
666
-
667
- for i in selected_indices:
668
- self.segments[i]["class_id"] = target_class_id
669
-
670
- self._update_next_class_id()
671
- self.update_all_lists()
672
- self.viewer.setFocus()
673
-
674
- def rasterize_polygon(self, vertices):
675
- if not vertices or self.viewer._pixmap_item.pixmap().isNull():
676
- return None
677
- h, w = (
678
- self.viewer._pixmap_item.pixmap().height(),
679
- self.viewer._pixmap_item.pixmap().width(),
680
- )
681
- points_np = np.array([[p.x(), p.y()] for p in vertices], dtype=np.int32)
682
- mask = np.zeros((h, w), dtype=np.uint8)
683
- cv2.fillPoly(mask, [points_np], 1)
684
- return mask.astype(bool)
685
-
686
- def display_all_segments(self):
687
- for i, items in self.segment_items.items():
688
- for item in items:
689
- self.viewer.scene().removeItem(item)
690
- self.segment_items.clear()
691
- selected_indices = self.get_selected_segment_indices()
692
-
693
- for i, seg_dict in enumerate(self.segments):
694
- self.segment_items[i] = []
695
- class_id = seg_dict.get("class_id")
696
- base_color = self._get_color_for_class(class_id)
697
-
698
- if seg_dict["type"] == "Polygon":
699
- poly_item = HoverablePolygonItem(QPolygonF(seg_dict["vertices"]))
700
- default_brush = QBrush(
701
- QColor(base_color.red(), base_color.green(), base_color.blue(), 70)
702
- )
703
- hover_brush = QBrush(
704
- QColor(base_color.red(), base_color.green(), base_color.blue(), 170)
705
- )
706
- poly_item.set_brushes(default_brush, hover_brush)
707
- poly_item.setPen(QPen(Qt.GlobalColor.transparent))
708
- self.viewer.scene().addItem(poly_item)
709
- self.segment_items[i].append(poly_item)
710
- base_color.setAlpha(150)
711
- vertex_color = QBrush(base_color)
712
- point_diameter = self.point_radius * 2
713
- for v in seg_dict["vertices"]:
714
- dot = QGraphicsEllipseItem(
715
- v.x() - self.point_radius,
716
- v.y() - self.point_radius,
717
- point_diameter,
718
- point_diameter,
719
- )
720
- dot.setBrush(vertex_color)
721
- dot.setPen(QPen(Qt.GlobalColor.transparent))
722
- self.viewer.scene().addItem(dot)
723
- self.segment_items[i].append(dot)
724
- if self.mode == "edit" and i in selected_indices:
725
- handle_diameter = self.point_radius * 2
726
- for idx, v in enumerate(seg_dict["vertices"]):
727
- vertex_item = EditableVertexItem(
728
- self,
729
- i,
730
- idx,
731
- -handle_diameter / 2,
732
- -handle_diameter / 2,
733
- handle_diameter,
734
- handle_diameter,
735
- )
736
- vertex_item.setPos(v)
737
- self.viewer.scene().addItem(vertex_item)
738
- self.segment_items[i].append(vertex_item)
739
- elif seg_dict.get("mask") is not None:
740
- default_pixmap = mask_to_pixmap(
741
- seg_dict["mask"], base_color.getRgb()[:3], alpha=70
742
- )
743
- hover_pixmap = mask_to_pixmap(
744
- seg_dict["mask"], base_color.getRgb()[:3], alpha=170
745
- )
746
- pixmap_item = HoverablePixmapItem()
747
- pixmap_item.set_pixmaps(default_pixmap, hover_pixmap)
748
- self.viewer.scene().addItem(pixmap_item)
749
- pixmap_item.setZValue(i + 1)
750
- self.segment_items[i].append(pixmap_item)
751
- self.highlight_selected_segments()
752
-
753
- def update_vertex_pos(self, seg_idx, vtx_idx, new_pos):
754
- self.segments[seg_idx]["vertices"][vtx_idx] = new_pos
755
- self.update_polygon_visuals(seg_idx)
756
-
757
- def update_polygon_visuals(self, segment_index):
758
- items = self.segment_items.get(segment_index, [])
759
- for item in items:
760
- if isinstance(item, HoverablePolygonItem):
761
- item.setPolygon(QPolygonF(self.segments[segment_index]["vertices"]))
762
- break
763
-
764
- def highlight_selected_segments(self):
765
- if hasattr(self, "highlight_items"):
766
- for item in self.highlight_items:
767
- self.viewer.scene().removeItem(item)
768
- self.highlight_items.clear()
769
- selected_indices = self.get_selected_segment_indices()
770
- for i in selected_indices:
771
- seg = self.segments[i]
772
- mask = (
773
- self.rasterize_polygon(seg["vertices"])
774
- if seg["type"] == "Polygon"
775
- else seg.get("mask")
776
- )
777
- if mask is not None:
778
- pixmap = mask_to_pixmap(mask, (255, 255, 255))
779
- highlight_item = self.viewer.scene().addPixmap(pixmap)
780
- highlight_item.setZValue(100)
781
- self.highlight_items.append(highlight_item)
782
-
783
- def update_all_lists(self):
784
- self.update_class_list()
785
- self.update_class_filter_combo()
786
- self.update_segment_table()
787
- self.display_all_segments()
788
-
789
- def update_segment_table(self):
790
- table = self.right_panel.segment_table
791
- table.blockSignals(True)
792
- selected_indices = self.get_selected_segment_indices()
793
- table.clearContents()
794
- table.setRowCount(0)
795
- filter_text = self.right_panel.class_filter_combo.currentText()
796
- show_all = filter_text == "All Classes"
797
- filter_class_id = -1
798
- if not show_all:
799
- try:
800
- filter_class_id = int(filter_text.split("(ID: ")[1][:-1])
801
- except (ValueError, IndexError):
802
- pass
803
-
804
- display_segments = []
805
- for i, seg in enumerate(self.segments):
806
- if show_all or seg.get("class_id") == filter_class_id:
807
- display_segments.append((i, seg))
808
-
809
- table.setRowCount(len(display_segments))
810
-
811
- for row, (original_index, seg) in enumerate(display_segments):
812
- class_id = seg.get("class_id")
813
- color = self._get_color_for_class(class_id)
814
- class_id_str = str(class_id) if class_id is not None else "N/A"
815
-
816
- alias_str = "N/A"
817
- if class_id is not None:
818
- alias_str = self.class_aliases.get(class_id, str(class_id))
819
- alias_item = QTableWidgetItem(alias_str)
820
-
821
- index_item = NumericTableWidgetItem(str(original_index + 1))
822
- class_item = NumericTableWidgetItem(class_id_str)
823
-
824
- index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
825
- class_item.setFlags(class_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
826
- alias_item.setFlags(alias_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
827
- index_item.setData(Qt.ItemDataRole.UserRole, original_index)
828
-
829
- table.setItem(row, 0, index_item)
830
- table.setItem(row, 1, class_item)
831
- table.setItem(row, 2, alias_item)
832
-
833
- for col in range(table.columnCount()):
834
- if table.item(row, col):
835
- table.item(row, col).setBackground(QBrush(color))
836
-
837
- table.setSortingEnabled(False)
838
- for row in range(table.rowCount()):
839
- item = table.item(row, 0)
840
- if item and item.data(Qt.ItemDataRole.UserRole) in selected_indices:
841
- table.selectRow(row)
842
- table.setSortingEnabled(True)
843
-
844
- table.blockSignals(False)
845
- self.viewer.setFocus()
846
-
847
- def update_class_list(self):
848
- class_table = self.right_panel.class_table
849
- class_table.blockSignals(True)
850
-
851
- preserved_aliases = self.class_aliases.copy()
852
- unique_class_ids = sorted(
853
- list(
854
- {
855
- seg.get("class_id")
856
- for seg in self.segments
857
- if seg.get("class_id") is not None
858
- }
859
- )
860
- )
861
-
862
- new_aliases = {}
863
- for cid in unique_class_ids:
864
- new_aliases[cid] = preserved_aliases.get(cid, str(cid))
865
-
866
- self.class_aliases = new_aliases
867
-
868
- class_table.clearContents()
869
- class_table.setRowCount(len(unique_class_ids))
870
- for row, cid in enumerate(unique_class_ids):
871
- alias_item = QTableWidgetItem(self.class_aliases.get(cid))
872
- id_item = QTableWidgetItem(str(cid))
873
- id_item.setFlags(id_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
874
- color = self._get_color_for_class(cid)
875
- alias_item.setBackground(QBrush(color))
876
- id_item.setBackground(QBrush(color))
877
- class_table.setItem(row, 0, alias_item)
878
- class_table.setItem(row, 1, id_item)
879
-
880
- class_table.blockSignals(False)
881
-
882
- def update_class_filter_combo(self):
883
- combo = self.right_panel.class_filter_combo
884
- unique_class_ids = sorted(
885
- list(
886
- {
887
- seg.get("class_id")
888
- for seg in self.segments
889
- if seg.get("class_id") is not None
890
- }
891
- )
892
- )
893
- current_selection = combo.currentText()
894
- combo.blockSignals(True)
895
- combo.clear()
896
- combo.addItem("All Classes")
897
- combo.addItems(
898
- [
899
- f"{self.class_aliases.get(cid, cid)} (ID: {cid})"
900
- for cid in unique_class_ids
901
- ]
902
- )
903
- if combo.findText(current_selection) > -1:
904
- combo.setCurrentText(current_selection)
905
- else:
906
- combo.setCurrentIndex(0)
907
- combo.blockSignals(False)
908
-
909
- def reassign_class_ids(self):
910
- class_table = self.right_panel.class_table
911
- ordered_ids = []
912
- for row in range(class_table.rowCount()):
913
- id_item = class_table.item(row, 1)
914
- if id_item and id_item.text():
915
- try:
916
- ordered_ids.append(int(id_item.text()))
917
- except ValueError:
918
- continue
919
- id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
920
- for seg in self.segments:
921
- old_id = seg.get("class_id")
922
- if old_id in id_map:
923
- seg["class_id"] = id_map[old_id]
924
- new_aliases = {
925
- id_map[old_id]: self.class_aliases.get(old_id, str(old_id))
926
- for old_id in ordered_ids
927
- if old_id in self.class_aliases
928
- }
929
- self.class_aliases = new_aliases
930
- self._update_next_class_id()
931
- self.update_all_lists()
932
- self.viewer.setFocus()
933
-
934
- def handle_alias_change(self, item):
935
- if item.column() != 0:
936
- return
937
- class_table = self.right_panel.class_table
938
- class_table.blockSignals(True)
939
- id_item = class_table.item(item.row(), 1)
940
- if id_item:
941
- try:
942
- class_id = int(id_item.text())
943
- self.class_aliases[class_id] = item.text()
944
- except (ValueError, AttributeError):
945
- pass
946
- class_table.blockSignals(False)
947
-
948
- self.update_class_filter_combo()
949
- self.update_segment_table()
950
-
951
- def get_selected_segment_indices(self):
952
- table = self.right_panel.segment_table
953
- selected_items = table.selectedItems()
954
- selected_rows = sorted(list({item.row() for item in selected_items}))
955
- return [
956
- table.item(row, 0).data(Qt.ItemDataRole.UserRole)
957
- for row in selected_rows
958
- if table.item(row, 0)
959
- ]
960
-
961
- def save_output_to_npz(self):
962
- save_npz = self.control_panel.chk_save_npz.isChecked()
963
- save_txt = self.control_panel.chk_save_txt.isChecked()
964
- save_aliases = self.control_panel.chk_save_class_aliases.isChecked()
965
-
966
- if not self.current_image_path or not any([save_npz, save_txt, save_aliases]):
967
- return
968
-
969
- self.right_panel.status_label.setText("Saving...")
970
- QApplication.processEvents()
971
-
972
- saved_something = False
973
-
974
- if save_npz or save_txt:
975
- if not self.segments:
976
- self.show_notification("No segments to save.")
977
- else:
978
- h, w = (
979
- self.viewer._pixmap_item.pixmap().height(),
980
- self.viewer._pixmap_item.pixmap().width(),
981
- )
982
- class_table = self.right_panel.class_table
983
- ordered_ids = [
984
- int(class_table.item(row, 1).text())
985
- for row in range(class_table.rowCount())
986
- if class_table.item(row, 1) is not None
987
- ]
988
-
989
- if not ordered_ids:
990
- self.show_notification("No classes defined for mask saving.")
991
- else:
992
- id_map = {
993
- old_id: new_id for new_id, old_id in enumerate(ordered_ids)
994
- }
995
- num_final_classes = len(ordered_ids)
996
- final_mask_tensor = np.zeros(
997
- (h, w, num_final_classes), dtype=np.uint8
998
- )
999
-
1000
- for seg in self.segments:
1001
- class_id = seg.get("class_id")
1002
- if class_id not in id_map:
1003
- continue
1004
- new_channel_idx = id_map[class_id]
1005
- mask = (
1006
- self.rasterize_polygon(seg["vertices"])
1007
- if seg["type"] == "Polygon"
1008
- else seg.get("mask")
1009
- )
1010
- if mask is not None:
1011
- final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
1012
- final_mask_tensor[:, :, new_channel_idx], mask
1013
- )
1014
- if save_npz:
1015
- npz_path = os.path.splitext(self.current_image_path)[0] + ".npz"
1016
- np.savez_compressed(
1017
- npz_path, mask=final_mask_tensor.astype(np.uint8)
1018
- )
1019
- self.file_model.update_cache_for_path(npz_path)
1020
-
1021
- self.file_model.set_highlighted_path(npz_path)
1022
- QTimer.singleShot(
1023
- 1500, lambda: self.file_model.set_highlighted_path(None)
1024
- )
1025
- saved_something = True
1026
- if save_txt:
1027
- if self.control_panel.chk_yolo_use_alias.isChecked():
1028
- class_labels = [
1029
- class_table.item(row, 0).text()
1030
- for row in range(class_table.rowCount())
1031
- ]
1032
- else:
1033
- class_labels = list(range(num_final_classes))
1034
-
1035
- txt_path = self.generate_yolo_annotations(
1036
- final_mask_tensor, class_labels
1037
- )
1038
- if txt_path:
1039
- self.file_model.update_cache_for_path(txt_path)
1040
- saved_something = True
1041
-
1042
- if save_aliases:
1043
- aliases_path = os.path.splitext(self.current_image_path)[0] + ".json"
1044
- aliases_to_save = {str(k): v for k, v in self.class_aliases.items()}
1045
- with open(aliases_path, "w") as f:
1046
- json.dump(aliases_to_save, f, indent=4)
1047
- saved_something = True
1048
-
1049
- if saved_something:
1050
- self.right_panel.status_label.setText("Saved!")
1051
- else:
1052
- self.right_panel.status_label.clear()
1053
-
1054
- QTimer.singleShot(3000, lambda: self.right_panel.status_label.clear())
1055
-
1056
- def generate_yolo_annotations(self, mask_tensor, class_labels):
1057
- output_path = os.path.splitext(self.current_image_path)[0] + ".txt"
1058
- h, w, num_channels = mask_tensor.shape
1059
-
1060
- directory_path = os.path.dirname(output_path)
1061
- os.makedirs(directory_path, exist_ok=True)
1062
-
1063
- yolo_annotations = []
1064
- for channel in range(num_channels):
1065
- single_channel_image = mask_tensor[:, :, channel]
1066
- if not np.any(single_channel_image):
1067
- continue
1068
-
1069
- contours, _ = cv2.findContours(
1070
- single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
1071
- )
1072
-
1073
- class_label = class_labels[channel]
1074
- for contour in contours:
1075
- x, y, width, height = cv2.boundingRect(contour)
1076
- center_x = (x + width / 2) / w
1077
- center_y = (y + height / 2) / h
1078
- normalized_width = width / w
1079
- normalized_height = height / h
1080
- yolo_entry = f"{class_label} {center_x} {center_y} {normalized_width} {normalized_height}"
1081
- yolo_annotations.append(yolo_entry)
1082
-
1083
- if not yolo_annotations:
1084
- return None
1085
-
1086
- with open(output_path, "w") as file:
1087
- for annotation in yolo_annotations:
1088
- file.write(annotation + "\n")
1089
-
1090
- return output_path
1091
-
1092
- def save_current_segment(self):
1093
- if (
1094
- self.mode != "sam_points"
1095
- or not hasattr(self, "preview_mask_item")
1096
- or not self.preview_mask_item
1097
- ):
1098
- return
1099
- mask = self.sam_model.predict(self.positive_points, self.negative_points)
1100
- if mask is not None:
1101
- self.segments.append(
1102
- {
1103
- "mask": mask,
1104
- "type": "SAM",
1105
- "vertices": None,
1106
- "class_id": self.next_class_id,
1107
- }
1108
- )
1109
- self._update_next_class_id()
1110
- self.clear_all_points()
1111
- self.update_all_lists()
1112
-
1113
- def delete_selected_segments(self):
1114
- selected_indices = self.get_selected_segment_indices()
1115
- if not selected_indices:
1116
- return
1117
- for i in sorted(selected_indices, reverse=True):
1118
- del self.segments[i]
1119
- self._update_next_class_id()
1120
- self.update_all_lists()
1121
- self.viewer.setFocus()
1122
-
1123
- def load_class_aliases(self):
1124
- if not self.current_image_path:
1125
- return
1126
- json_path = os.path.splitext(self.current_image_path)[0] + ".json"
1127
- if os.path.exists(json_path):
1128
- try:
1129
- with open(json_path, "r") as f:
1130
- loaded_aliases = json.load(f)
1131
- # JSON loads keys as strings, convert them to int
1132
- self.class_aliases = {int(k): v for k, v in loaded_aliases.items()}
1133
- except (json.JSONDecodeError, ValueError) as e:
1134
- print(f"Error loading class aliases from {json_path}: {e}")
1135
- self.class_aliases.clear()
1136
-
1137
- def load_existing_mask(self):
1138
- if not self.current_image_path:
1139
- return
1140
- npz_path = os.path.splitext(self.current_image_path)[0] + ".npz"
1141
- if os.path.exists(npz_path):
1142
- with np.load(npz_path) as data:
1143
- if "mask" in data:
1144
- mask_data = data["mask"]
1145
- if mask_data.ndim == 2:
1146
- mask_data = np.expand_dims(mask_data, axis=-1)
1147
- num_classes = mask_data.shape[2]
1148
- for i in range(num_classes):
1149
- class_mask = mask_data[:, :, i].astype(bool)
1150
- if np.any(class_mask):
1151
- self.segments.append(
1152
- {
1153
- "mask": class_mask,
1154
- "type": "Loaded",
1155
- "vertices": None,
1156
- "class_id": i,
1157
- }
1158
- )
1159
- self._update_next_class_id()
1160
- self.update_all_lists()
1161
-
1162
- def add_point(self, pos, positive):
1163
- point_list = self.positive_points if positive else self.negative_points
1164
- point_list.append([int(pos.x()), int(pos.y())])
1165
- point_color = (
1166
- QColor(Qt.GlobalColor.green) if positive else QColor(Qt.GlobalColor.red)
1167
- )
1168
- point_color.setAlpha(150)
1169
- point_diameter = self.point_radius * 2
1170
- point_item = QGraphicsEllipseItem(
1171
- pos.x() - self.point_radius,
1172
- pos.y() - self.point_radius,
1173
- point_diameter,
1174
- point_diameter,
1175
- )
1176
- point_item.setBrush(QBrush(point_color))
1177
- point_item.setPen(QPen(Qt.GlobalColor.transparent))
1178
- self.viewer.scene().addItem(point_item)
1179
- self.point_items.append(point_item)
1180
-
1181
- def update_segmentation(self):
1182
- if hasattr(self, "preview_mask_item") and self.preview_mask_item:
1183
- self.viewer.scene().removeItem(self.preview_mask_item)
1184
- if not self.positive_points:
1185
- return
1186
- mask = self.sam_model.predict(self.positive_points, self.negative_points)
1187
- if mask is not None:
1188
- pixmap = mask_to_pixmap(mask, (255, 255, 0))
1189
- self.preview_mask_item = self.viewer.scene().addPixmap(pixmap)
1190
- self.preview_mask_item.setZValue(50)
1191
-
1192
- def clear_all_points(self):
1193
- if self.rubber_band_line:
1194
- self.viewer.scene().removeItem(self.rubber_band_line)
1195
- self.rubber_band_line = None
1196
- self.positive_points.clear()
1197
- self.negative_points.clear()
1198
- for item in self.point_items:
1199
- self.viewer.scene().removeItem(item)
1200
- self.point_items.clear()
1201
- self.polygon_points.clear()
1202
- for item in self.polygon_preview_items:
1203
- self.viewer.scene().removeItem(item)
1204
- self.polygon_preview_items.clear()
1205
- if hasattr(self, "preview_mask_item") and self.preview_mask_item:
1206
- self.viewer.scene().removeItem(self.preview_mask_item)
1207
- self.preview_mask_item = None
1208
-
1209
- def handle_polygon_click(self, pos):
1210
- if self.polygon_points and (
1211
- (
1212
- (pos.x() - self.polygon_points[0].x()) ** 2
1213
- + (pos.y() - self.polygon_points[0].y()) ** 2
1214
- )
1215
- < self.polygon_join_threshold**2
1216
- ):
1217
- if len(self.polygon_points) > 2:
1218
- self.finalize_polygon()
1219
- return
1220
- self.polygon_points.append(pos)
1221
- point_diameter = self.point_radius * 2
1222
- point_color = QColor(Qt.GlobalColor.blue)
1223
- point_color.setAlpha(150)
1224
- dot = QGraphicsEllipseItem(
1225
- pos.x() - self.point_radius,
1226
- pos.y() - self.point_radius,
1227
- point_diameter,
1228
- point_diameter,
1229
- )
1230
- dot.setBrush(QBrush(point_color))
1231
- dot.setPen(QPen(Qt.GlobalColor.transparent))
1232
- self.viewer.scene().addItem(dot)
1233
- self.polygon_preview_items.append(dot)
1234
- self.draw_polygon_preview()
1235
-
1236
- def draw_polygon_preview(self):
1237
- for item in self.polygon_preview_items:
1238
- if not isinstance(item, QGraphicsEllipseItem):
1239
- if item.scene():
1240
- self.viewer.scene().removeItem(item)
1241
- self.polygon_preview_items = [
1242
- item
1243
- for item in self.polygon_preview_items
1244
- if isinstance(item, QGraphicsEllipseItem)
1245
- ]
1246
- if len(self.polygon_points) > 2:
1247
- preview_poly = QGraphicsPolygonItem(QPolygonF(self.polygon_points))
1248
- preview_poly.setBrush(QBrush(QColor(0, 255, 255, 100)))
1249
- preview_poly.setPen(QPen(Qt.GlobalColor.transparent))
1250
- self.viewer.scene().addItem(preview_poly)
1251
- self.polygon_preview_items.append(preview_poly)
5
+ from PyQt6.QtWidgets import QApplication
1252
6
 
1253
- if len(self.polygon_points) > 1:
1254
- line_color = QColor(Qt.GlobalColor.cyan)
1255
- line_color.setAlpha(150)
1256
- for i in range(len(self.polygon_points) - 1):
1257
- line = QGraphicsLineItem(
1258
- self.polygon_points[i].x(),
1259
- self.polygon_points[i].y(),
1260
- self.polygon_points[i + 1].x(),
1261
- self.polygon_points[i + 1].y(),
1262
- )
1263
- line.setPen(QPen(line_color, self.line_thickness))
1264
- self.viewer.scene().addItem(line)
1265
- self.polygon_preview_items.append(line)
7
+ from .ui.main_window import MainWindow
1266
8
 
1267
9
 
1268
10
  def main():
11
+ """Main application entry point."""
1269
12
  app = QApplication(sys.argv)
1270
13
  qdarktheme.setup_theme()
1271
- sam_model = SamModel(model_type="vit_h")
1272
- main_win = MainWindow(sam_model)
1273
- main_win.show()
14
+
15
+ main_window = MainWindow()
16
+ main_window.show()
17
+
1274
18
  sys.exit(app.exec())
1275
19
 
1276
20
 
1277
21
  if __name__ == "__main__":
1278
- main()
22
+ main()