lazylabel-gui 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lazylabel/controls.py +108 -0
- lazylabel/custom_file_system_model.py +57 -0
- lazylabel/editable_vertex.py +25 -0
- lazylabel/hoverable_polygon_item.py +23 -0
- lazylabel/main.py +864 -0
- lazylabel/numeric_table_widget_item.py +10 -0
- lazylabel/photo_viewer.py +51 -0
- lazylabel/reorderable_class_table.py +58 -0
- lazylabel/sam_model.py +70 -0
- lazylabel/utils.py +12 -0
- lazylabel_gui-1.0.0.dist-info/METADATA +147 -0
- lazylabel_gui-1.0.0.dist-info/RECORD +16 -0
- lazylabel_gui-1.0.0.dist-info/WHEEL +5 -0
- lazylabel_gui-1.0.0.dist-info/entry_points.txt +2 -0
- lazylabel_gui-1.0.0.dist-info/licenses/LICENSE +21 -0
- lazylabel_gui-1.0.0.dist-info/top_level.txt +1 -0
lazylabel/main.py
ADDED
@@ -0,0 +1,864 @@
|
|
1
|
+
import sys
|
2
|
+
import os
|
3
|
+
import numpy as np
|
4
|
+
import qdarktheme
|
5
|
+
import cv2
|
6
|
+
from PyQt6.QtWidgets import (
|
7
|
+
QApplication,
|
8
|
+
QMainWindow,
|
9
|
+
QWidget,
|
10
|
+
QHBoxLayout,
|
11
|
+
QFileDialog,
|
12
|
+
QGraphicsItem,
|
13
|
+
QGraphicsEllipseItem,
|
14
|
+
QGraphicsLineItem,
|
15
|
+
QTableWidgetItem,
|
16
|
+
QGraphicsPolygonItem,
|
17
|
+
QTableWidgetSelectionRange,
|
18
|
+
)
|
19
|
+
from PyQt6.QtGui import QPixmap, QColor, QPen, QBrush, QPolygonF, QIcon
|
20
|
+
from PyQt6.QtCore import Qt, QPointF, QTimer
|
21
|
+
|
22
|
+
# Relative imports for package structure
|
23
|
+
from .photo_viewer import PhotoViewer
|
24
|
+
from .sam_model import SamModel
|
25
|
+
from .utils import mask_to_pixmap
|
26
|
+
from .controls import ControlPanel, RightPanel
|
27
|
+
from .custom_file_system_model import CustomFileSystemModel
|
28
|
+
from .editable_vertex import EditableVertexItem
|
29
|
+
from .hoverable_polygon_item import HoverablePolygonItem
|
30
|
+
from .numeric_table_widget_item import NumericTableWidgetItem
|
31
|
+
|
32
|
+
|
33
|
+
class MainWindow(QMainWindow):
    """LazyLabel main window: image viewer, tool panels and annotation state.

    Each entry of ``self.segments`` is a dict with the keys:

    * ``"type"`` -- ``"SAM"``, ``"Polygon"`` or ``"Loaded"``;
    * ``"mask"`` -- ``None`` for polygons, otherwise the mask returned by
      ``SamModel.predict`` or a boolean array read back from an ``.npz`` file;
    * ``"vertices"`` -- list of ``QPointF`` for polygons, otherwise ``None``;
    * ``"class_id"`` -- integer class label.
    """

    # QGraphicsItem.setData() key under which add_point() records whether a
    # SAM prompt point is positive, so undo can pop the matching point list.
    _POINT_POLARITY_KEY = 0

    def __init__(self, sam_model):
        """Build the UI around an already-constructed ``SamModel``.

        Args:
            sam_model: model wrapper exposing ``device``, ``set_image(path)``
                and ``predict(positive_points, negative_points)``.
        """
        super().__init__()
        self.setWindowTitle("LazyLabel by DNC")

        icon_path = os.path.join(
            os.path.dirname(__file__), "demo_pictures", "logo2.png"
        )
        if os.path.exists(icon_path):
            self.setWindowIcon(QIcon(icon_path))

        self.setGeometry(50, 50, 1600, 900)

        # The SamModel instance is passed in so it is constructed only once.
        self.sam_model = sam_model
        self.mode = "sam_points"
        self.previous_mode = "sam_points"
        self.current_image_path = None
        self.current_file_index = None
        self.next_class_id = 0

        # In-progress SAM prompt points ([x, y] int lists) and their dots.
        self.point_items, self.positive_points, self.negative_points = [], [], []
        # In-progress polygon vertices (QPointF) and their preview graphics.
        self.polygon_points, self.polygon_preview_items = [], []
        self.rubber_band_line = None
        # Current SAM preview overlay and the mask it was drawn from.
        self.preview_mask_item = None
        self.preview_mask = None

        # Finished segments, their scene items (keyed by segment index) and
        # the white selection-highlight overlays.
        self.segments, self.segment_items, self.highlight_items = [], {}, []
        # Edit-mode state for dragging whole polygons.
        self.is_dragging_polygon = False
        self.drag_start_pos = None
        self.drag_initial_vertices = {}

        self.control_panel = ControlPanel()
        self.right_panel = RightPanel()
        self.viewer = PhotoViewer(self)
        self.viewer.setMouseTracking(True)
        self.file_model = CustomFileSystemModel()
        self.right_panel.file_tree.setModel(self.file_model)
        self.right_panel.file_tree.setColumnWidth(0, 200)

        main_layout = QHBoxLayout()
        main_layout.addWidget(self.control_panel)
        main_layout.addWidget(self.viewer, 1)
        main_layout.addWidget(self.right_panel)
        central_widget = QWidget()
        central_widget.setLayout(main_layout)
        self.setCentralWidget(central_widget)

        self.control_panel.device_label.setText(
            f"Device: {str(self.sam_model.device).upper()}"
        )
        self.setup_connections()
        self.set_sam_mode()

    def setup_connections(self):
        """Hook the scene's mouse events and wire panel widgets to handlers."""
        scene = self.viewer.scene()
        # Keep the scene's own handlers so ours can delegate to them.
        self._original_mouse_press = scene.mousePressEvent
        self._original_mouse_move = scene.mouseMoveEvent
        self._original_mouse_release = scene.mouseReleaseEvent

        scene.mousePressEvent = self.scene_mouse_press
        scene.mouseMoveEvent = self.scene_mouse_move
        scene.mouseReleaseEvent = self.scene_mouse_release

        self.right_panel.btn_open_folder.clicked.connect(self.open_folder_dialog)
        self.right_panel.file_tree.doubleClicked.connect(self.load_selected_image)
        self.right_panel.btn_merge_selection.clicked.connect(
            self.assign_selected_to_class
        )
        self.right_panel.btn_delete_selection.clicked.connect(
            self.delete_selected_segments
        )
        self.right_panel.segment_table.itemSelectionChanged.connect(
            self.highlight_selected_segments
        )
        self.right_panel.segment_table.itemChanged.connect(self.handle_class_id_change)
        self.right_panel.btn_reassign_classes.clicked.connect(self.reassign_class_ids)
        self.right_panel.class_filter_combo.currentIndexChanged.connect(
            self.update_segment_table
        )

        self.control_panel.btn_sam_mode.clicked.connect(self.set_sam_mode)
        self.control_panel.btn_polygon_mode.clicked.connect(self.set_polygon_mode)
        self.control_panel.btn_selection_mode.clicked.connect(
            self.toggle_selection_mode
        )
        self.control_panel.btn_clear_points.clicked.connect(self.clear_all_points)

    def set_mode(self, mode_name, is_toggle=False):
        """Switch the interaction mode and discard any in-progress drawing.

        Args:
            mode_name: one of "sam_points", "polygon", "selection", "pan", "edit".
            is_toggle: when True the caller maintains ``previous_mode`` itself.
        """
        leaving_edit = self.mode == "edit" and mode_name != "edit"
        if not is_toggle and self.mode not in ("pan", "selection", "edit"):
            self.previous_mode = self.mode
        self.mode = mode_name
        self.control_panel.mode_label.setText(
            f"Mode: {mode_name.replace('_', ' ').title()}"
        )
        self.clear_all_points()
        self.viewer.setDragMode(
            self.viewer.DragMode.ScrollHandDrag
            if self.mode == "pan"
            else self.viewer.DragMode.NoDrag
        )
        if leaving_edit:
            # Abandon any polygon drag that was in progress.
            self.is_dragging_polygon = False
            self.drag_initial_vertices.clear()
            # Redraw only after self.mode has changed. Doing it earlier would
            # make display_all_segments re-create the edit vertex handles.
            self.display_all_segments()

    def set_sam_mode(self):
        """Switch to point-prompted SAM segmentation."""
        self.set_mode("sam_points")

    def set_polygon_mode(self):
        """Switch to manual polygon drawing."""
        self.set_mode("polygon")

    def toggle_mode(self, new_mode):
        """Enter ``new_mode``, or return to the previous mode if already in it."""
        if self.mode == new_mode:
            self.set_mode(self.previous_mode, is_toggle=True)
        else:
            if self.mode not in ("pan", "selection", "edit"):
                self.previous_mode = self.mode
            self.set_mode(new_mode, is_toggle=True)

    def toggle_pan_mode(self):
        """Toggle scroll-hand panning."""
        self.toggle_mode("pan")

    def toggle_selection_mode(self):
        """Toggle click-to-select of existing segments."""
        self.toggle_mode("selection")

    def toggle_edit_mode(self):
        """Toggle between selection and polygon edit mode.

        Edit mode can only be entered from selection mode, and only while at
        least one selected segment is a polygon.
        """
        if self.mode == "edit":
            # set_mode redraws the segments without edit handles.
            self.set_mode("selection", is_toggle=True)
        elif self.mode == "selection":
            selected_indices = self.get_selected_segment_indices()
            can_edit = any(
                self.segments[i].get("type") == "Polygon" for i in selected_indices
            )
            if can_edit:
                self.set_mode("edit", is_toggle=True)
                # Redraw so the selected polygons get draggable vertex handles.
                self.display_all_segments()

    def open_folder_dialog(self):
        """Ask for an image folder and show it in the file tree."""
        folder_path = QFileDialog.getExistingDirectory(self, "Select Image Folder")
        if folder_path:
            self.right_panel.file_tree.setRootIndex(
                self.file_model.setRootPath(folder_path)
            )
        self.viewer.setFocus()

    def load_selected_image(self, index):
        """Load the PNG/JPEG behind a file-tree ``index`` plus any saved mask."""
        if not index.isValid():
            return

        self.current_file_index = index
        path = self.file_model.filePath(index)

        if os.path.isfile(path) and path.lower().endswith((".png", ".jpg", ".jpeg")):
            self.current_image_path = path
            pixmap = QPixmap(self.current_image_path)
            if not pixmap.isNull():
                self.reset_state()
                self.viewer.set_photo(pixmap)
                self.sam_model.set_image(self.current_image_path)
                self.load_existing_mask()
                self.viewer.setFocus()

    def reset_state(self):
        """Drop all segments and every scene item except the image itself."""
        self.clear_all_points()
        self.segments.clear()
        self.next_class_id = 0
        self.update_all_lists()
        scene = self.viewer.scene()
        items_to_remove = [
            item for item in scene.items() if item is not self.viewer._pixmap_item
        ]
        for item in items_to_remove:
            scene.removeItem(item)
        self.segment_items.clear()
        self.highlight_items.clear()

    def _scroll(self, scroll_bar, delta):
        """Move ``scroll_bar`` by ``delta`` pixels (negative means up/left)."""
        scroll_bar.setValue(scroll_bar.value() + delta)

    def keyPressEvent(self, event):
        """Handle the keyboard shortcuts; unhandled keys go to the base class.

        WASD scroll, 1/2 pick SAM/polygon mode, E/Q/R toggle selection/pan/
        edit, C clears points, V/Backspace deletes, M merges classes, Ctrl+Z
        undoes, Ctrl+A selects all, Space commits the SAM preview and
        Enter/Return saves the .npz.
        """
        key, mods = event.key(), event.modifiers()
        if event.isAutoRepeat():
            return
        ctrl = Qt.KeyboardModifier.ControlModifier
        v_step = int(self.viewer.height() * 0.1)
        h_step = int(self.viewer.width() * 0.1)
        if key == Qt.Key.Key_W:
            self._scroll(self.viewer.verticalScrollBar(), -v_step)
        elif key == Qt.Key.Key_S and not mods:
            self._scroll(self.viewer.verticalScrollBar(), v_step)
        elif key == Qt.Key.Key_A and not (mods & ctrl):
            self._scroll(self.viewer.horizontalScrollBar(), -h_step)
        elif key == Qt.Key.Key_D:
            self._scroll(self.viewer.horizontalScrollBar(), h_step)
        elif key == Qt.Key.Key_1:
            self.set_sam_mode()
        elif key == Qt.Key.Key_2:
            self.set_polygon_mode()
        elif key == Qt.Key.Key_E:
            self.toggle_selection_mode()
        elif key == Qt.Key.Key_Q:
            self.toggle_pan_mode()
        elif key == Qt.Key.Key_R:
            self.toggle_edit_mode()
        elif key == Qt.Key.Key_C:
            self.clear_all_points()
        elif key in (Qt.Key.Key_V, Qt.Key.Key_Backspace):
            self.delete_selected_segments()
        elif key == Qt.Key.Key_M:
            self.assign_selected_to_class()
        elif key == Qt.Key.Key_Z and mods == ctrl:
            self.undo_last_action()
        elif key == Qt.Key.Key_A and mods == ctrl:
            self.right_panel.segment_table.selectAll()
        elif key == Qt.Key.Key_Space:
            self.save_current_segment()
        elif key in (Qt.Key.Key_Return, Qt.Key.Key_Enter):
            self.save_output_to_npz()
        else:
            # Qt asks overrides to forward keys they do not act on.
            super().keyPressEvent(event)

    def scene_mouse_press(self, event):
        """Dispatch a press on the image according to the current mode."""
        self._original_mouse_press(event)
        if event.isAccepted():
            # A scene item took the click (presumably an edit vertex handle).
            return
        pos = event.scenePos()
        pixmap = self.viewer._pixmap_item.pixmap()
        if pixmap.isNull() or not pixmap.rect().contains(pos.toPoint()):
            return
        button = event.button()
        if self.mode == "sam_points":
            if button in (Qt.MouseButton.LeftButton, Qt.MouseButton.RightButton):
                # Left click adds a positive prompt, right click a negative one.
                self.add_point(pos, positive=button == Qt.MouseButton.LeftButton)
                self.update_segmentation()
        elif self.mode == "polygon":
            if button == Qt.MouseButton.LeftButton:
                self.handle_polygon_click(pos)
        elif self.mode == "selection":
            if button == Qt.MouseButton.LeftButton:
                self.handle_segment_selection_click(pos)
        elif self.mode == "edit":
            # Start dragging every selected polygon; remember where each began.
            self.drag_start_pos = pos
            self.is_dragging_polygon = True
            self.drag_initial_vertices = {
                i: list(self.segments[i]["vertices"])
                for i in self.get_selected_segment_indices()
                if self.segments[i].get("type") == "Polygon"
            }

    def scene_mouse_move(self, event):
        """Drag polygons in edit mode or stretch the polygon rubber band."""
        pos = event.scenePos()
        if self.mode == "edit" and self.is_dragging_polygon:
            delta = pos - self.drag_start_pos
            for i, initial_verts in self.drag_initial_vertices.items():
                self.segments[i]["vertices"] = [
                    QPointF(v.x() + delta.x(), v.y() + delta.y()) for v in initial_verts
                ]
                self.update_polygon_visuals(i)
        elif self.mode == "polygon" and self.polygon_points:
            if self.rubber_band_line is None:
                self.rubber_band_line = QGraphicsLineItem()
                self.rubber_band_line.setPen(
                    QPen(Qt.GlobalColor.white, 2, Qt.PenStyle.DotLine)
                )
                self.viewer.scene().addItem(self.rubber_band_line)
            last = self.polygon_points[-1]
            self.rubber_band_line.setLine(last.x(), last.y(), pos.x(), pos.y())
            self.rubber_band_line.show()
        else:
            self._original_mouse_move(event)

    def scene_mouse_release(self, event):
        """Finish an edit-mode polygon drag and forward the release to Qt."""
        finished_drag = self.mode == "edit" and self.is_dragging_polygon
        dragged_any = finished_drag and bool(self.drag_initial_vertices)
        if finished_drag:
            self.is_dragging_polygon = False
            self.drag_initial_vertices.clear()
        self._original_mouse_release(event)
        if dragged_any:
            # During the drag only the polygon fills followed the mouse.
            # Redraw so the vertex dots, edit handles and highlights catch up.
            self.display_all_segments()

    def undo_last_action(self):
        """Undo the latest polygon vertex or SAM prompt point (Ctrl+Z)."""
        scene = self.viewer.scene()
        if self.mode == "polygon" and self.polygon_points:
            self.polygon_points.pop()
            # Remove the dot of the undone vertex. polygon_preview_items also
            # holds edge lines and the fill polygon, appended after the dots.
            # Popping the list's last entry would therefore drop a line and
            # leave an orphan dot. draw_polygon_preview rebuilds the rest.
            dots = [
                item
                for item in self.polygon_preview_items
                if isinstance(item, QGraphicsEllipseItem)
            ]
            if dots:
                scene.removeItem(dots[-1])
                self.polygon_preview_items.remove(dots[-1])
            self.draw_polygon_preview()
        elif self.mode == "sam_points" and self.point_items:
            item_to_remove = self.point_items.pop()
            # add_point appends to point_items and to the polarity list in
            # lockstep, so the newest item is the last entry of its own list.
            is_positive = item_to_remove.data(self._POINT_POLARITY_KEY)
            point_list = self.positive_points if is_positive else self.negative_points
            if point_list:
                point_list.pop()
            scene.removeItem(item_to_remove)
            self.update_segmentation()

    def finalize_polygon(self):
        """Store the drawn polygon (>= 3 vertices) as a new segment."""
        if len(self.polygon_points) < 3:
            return
        if self.rubber_band_line:
            self.viewer.scene().removeItem(self.rubber_band_line)
            self.rubber_band_line = None
        self.segments.append(
            {
                "vertices": list(self.polygon_points),
                "type": "Polygon",
                "mask": None,
                "class_id": self.next_class_id,
            }
        )
        self.next_class_id += 1
        self.polygon_points.clear()
        for item in self.polygon_preview_items:
            self.viewer.scene().removeItem(item)
        self.polygon_preview_items.clear()
        self.update_all_lists()

    def handle_segment_selection_click(self, pos):
        """Toggle the table selection of the top-most segment under ``pos``."""
        x, y = int(pos.x()), int(pos.y())
        table = self.right_panel.segment_table
        # Later segments are stacked above earlier ones, so search backwards.
        for i in reversed(range(len(self.segments))):
            mask = self._segment_mask(self.segments[i])
            if mask is None or y >= mask.shape[0] or x >= mask.shape[1]:
                continue
            if not mask[y, x]:
                continue
            for row in range(table.rowCount()):
                item = table.item(row, 0)
                if item and item.data(Qt.ItemDataRole.UserRole) == i:
                    whole_row = QTableWidgetSelectionRange(
                        row, 0, row, table.columnCount() - 1
                    )
                    table.setRangeSelected(whole_row, not item.isSelected())
                    return
        self.viewer.setFocus()

    def assign_selected_to_class(self):
        """Give every selected segment the class id of the first selected one."""
        selected_indices = self.get_selected_segment_indices()
        if not selected_indices:
            return
        target_class_id = self.segments[selected_indices[0]]["class_id"]
        for i in selected_indices:
            self.segments[i]["class_id"] = target_class_id
        self.update_all_lists()
        self.viewer.setFocus()

    def rasterize_polygon(self, vertices):
        """Return a boolean image-sized mask of the filled polygon.

        Returns None when there are no vertices or no image is loaded.
        Vertex coordinates are truncated to ints before filling.
        """
        pixmap = self.viewer._pixmap_item.pixmap()
        if not vertices or pixmap.isNull():
            return None
        h, w = pixmap.height(), pixmap.width()
        points_np = np.array([[p.x(), p.y()] for p in vertices], dtype=np.int32)
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask, [points_np], 1)
        return mask.astype(bool)

    def _segment_mask(self, seg):
        """Return a segment's pixel mask, rasterizing polygons on demand."""
        if seg["type"] == "Polygon":
            return self.rasterize_polygon(seg["vertices"])
        return seg.get("mask")

    def _unique_class_ids(self):
        """Return the sorted distinct non-None class ids of all segments."""
        return sorted(
            {
                seg.get("class_id")
                for seg in self.segments
                if seg.get("class_id") is not None
            }
        )

    def _class_hue_lookup(self, unique_class_ids=None):
        """Return a function mapping a class id to a display hue in [0, 360).

        Hues are spread evenly over the distinct class ids in use, so colors
        stay distinct whatever the id values are. Unknown ids get hue 0.
        """
        if unique_class_ids is None:
            unique_class_ids = self._unique_class_ids()
        num_classes = len(unique_class_ids) or 1
        hue_index = {cid: i for i, cid in enumerate(unique_class_ids)}

        def hue_for(class_id):
            return int(hue_index.get(class_id, 0) * 360 / num_classes) % 360

        return hue_for

    def display_all_segments(self):
        """Rebuild the scene items for every segment, colored by class.

        Polygons get a translucent fill (stronger on hover) and vertex dots.
        In edit mode, selected polygons also get draggable vertex handles.
        Mask segments are drawn as colored pixmaps.
        """
        scene = self.viewer.scene()
        for items in self.segment_items.values():
            for item in items:
                scene.removeItem(item)
        self.segment_items.clear()
        selected_indices = set(self.get_selected_segment_indices())
        hue_for = self._class_hue_lookup()

        for i, seg_dict in enumerate(self.segments):
            self.segment_items[i] = []
            base_color = QColor.fromHsv(hue_for(seg_dict.get("class_id", 0)), 220, 220)
            r, g, b = base_color.red(), base_color.green(), base_color.blue()

            if seg_dict["type"] == "Polygon":
                poly_item = HoverablePolygonItem(QPolygonF(seg_dict["vertices"]))
                poly_item.set_brushes(
                    QBrush(QColor(r, g, b, 70)), QBrush(QColor(r, g, b, 170))
                )
                poly_item.setPen(QPen(Qt.GlobalColor.transparent))
                scene.addItem(poly_item)
                self.segment_items[i].append(poly_item)
                vertex_color = QBrush(base_color)
                for v in seg_dict["vertices"]:
                    dot = QGraphicsEllipseItem(v.x() - 3, v.y() - 3, 6, 6)
                    dot.setBrush(vertex_color)
                    scene.addItem(dot)
                    self.segment_items[i].append(dot)
                if self.mode == "edit" and i in selected_indices:
                    for idx, v in enumerate(seg_dict["vertices"]):
                        vertex_item = EditableVertexItem(self, i, idx, -4, -4, 8, 8)
                        vertex_item.setPos(v)
                        scene.addItem(vertex_item)
                        self.segment_items[i].append(vertex_item)
            elif seg_dict.get("mask") is not None:
                pixmap = mask_to_pixmap(seg_dict["mask"], base_color.getRgb()[:3])
                pixmap_item = scene.addPixmap(pixmap)
                # Later segments stack above earlier ones.
                pixmap_item.setZValue(i + 1)
                self.segment_items[i].append(pixmap_item)
        self.highlight_selected_segments()

    def update_vertex_pos(self, seg_idx, vtx_idx, new_pos):
        """Move one polygon vertex and refresh the polygon fill.

        Presumably invoked by EditableVertexItem while a handle is dragged
        (it is constructed with this window); confirm in editable_vertex.py.
        """
        self.segments[seg_idx]["vertices"][vtx_idx] = new_pos
        self.update_polygon_visuals(seg_idx)

    def update_polygon_visuals(self, segment_index):
        """Reshape a segment's polygon fill to match its current vertices."""
        for item in self.segment_items.get(segment_index, []):
            if isinstance(item, HoverablePolygonItem):
                item.setPolygon(QPolygonF(self.segments[segment_index]["vertices"]))
                break

    def highlight_selected_segments(self):
        """Overlay a white highlight (z=100) on every selected segment."""
        scene = self.viewer.scene()
        for item in self.highlight_items:
            scene.removeItem(item)
        self.highlight_items.clear()
        for i in self.get_selected_segment_indices():
            mask = self._segment_mask(self.segments[i])
            if mask is not None:
                highlight_item = scene.addPixmap(mask_to_pixmap(mask, (255, 255, 255)))
                highlight_item.setZValue(100)
                self.highlight_items.append(highlight_item)

    def update_all_lists(self):
        """Refresh the class filter, segment table, class list and scene."""
        self.update_class_filter_combo()
        self.update_segment_table()
        self.update_class_list()
        self.display_all_segments()

    def update_segment_table(self):
        """Repopulate the segment table under the class filter, keeping selection.

        Column 0 shows the 1-based segment number and stores the 0-based
        segment index under ``Qt.ItemDataRole.UserRole``. Column 1 holds the
        editable class id and column 2 the segment type.
        """
        table = self.right_panel.segment_table
        table.blockSignals(True)
        selected_indices = set(self.get_selected_segment_indices())
        # While sorting is enabled, each setItem() can re-sort rows, so later
        # cells of the same row would land in the wrong row. Sort once at the end.
        table.setSortingEnabled(False)
        table.clearContents()
        table.setRowCount(0)

        filter_text = self.right_panel.class_filter_combo.currentText()
        show_all = filter_text == "All Classes"
        filter_class_id = -1
        if not show_all:
            try:
                filter_class_id = int(filter_text.split(" ")[1])
            except (ValueError, IndexError):
                pass

        display_segments = [
            (i, seg)
            for i, seg in enumerate(self.segments)
            if show_all or seg.get("class_id") == filter_class_id
        ]
        table.setRowCount(len(display_segments))
        hue_for = self._class_hue_lookup()

        for row, (original_index, seg) in enumerate(display_segments):
            class_id = seg.get("class_id", 0)
            color = QColor.fromHsv(hue_for(class_id), 150, 100)

            index_item = NumericTableWidgetItem(str(original_index + 1))
            class_item = NumericTableWidgetItem(str(class_id))
            type_item = QTableWidgetItem(seg["type"])

            index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
            index_item.setData(Qt.ItemDataRole.UserRole, original_index)

            for col, cell in enumerate((index_item, class_item, type_item)):
                cell.setBackground(QBrush(color))
                table.setItem(row, col, cell)

        # setRangeSelected adds to the selection. selectRow() would issue a
        # ClearAndSelect without modifiers and keep only the last row.
        last_col = table.columnCount() - 1
        for row in range(table.rowCount()):
            if table.item(row, 0).data(Qt.ItemDataRole.UserRole) in selected_indices:
                table.setRangeSelected(
                    QTableWidgetSelectionRange(row, 0, row, last_col), True
                )
        table.setSortingEnabled(True)

        table.blockSignals(False)
        self.viewer.setFocus()

    def update_class_list(self):
        """Show one colored, read-only row per distinct class id."""
        class_table = self.right_panel.class_table
        class_table.blockSignals(True)
        unique_class_ids = self._unique_class_ids()
        hue_for = self._class_hue_lookup(unique_class_ids)
        class_table.setRowCount(len(unique_class_ids))
        for row, cid in enumerate(unique_class_ids):
            item = QTableWidgetItem(str(cid))
            item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
            item.setBackground(QBrush(QColor.fromHsv(hue_for(cid), 150, 100)))
            class_table.setItem(row, 0, item)
        class_table.blockSignals(False)

    def update_class_filter_combo(self):
        """Rebuild the class filter entries, keeping the current choice if valid."""
        combo = self.right_panel.class_filter_combo
        unique_class_ids = self._unique_class_ids()
        current_selection = combo.currentText()
        combo.blockSignals(True)
        combo.clear()
        combo.addItem("All Classes")
        combo.addItems([f"Class {cid}" for cid in unique_class_ids])
        if combo.findText(current_selection) > -1:
            combo.setCurrentText(current_selection)
        else:
            combo.setCurrentIndex(0)
        combo.blockSignals(False)

    def reassign_class_ids(self):
        """Renumber classes 0..N-1 in the class table's current row order.

        Assumes the user may have reordered the class table rows (e.g. via
        drag and drop); confirm against the RightPanel widget.
        """
        class_table = self.right_panel.class_table
        ordered_ids = [
            int(class_table.item(row, 0).text())
            for row in range(class_table.rowCount())
            if class_table.item(row, 0) is not None
        ]
        id_map = {old_id: new_id for new_id, old_id in enumerate(ordered_ids)}
        for seg in self.segments:
            old_id = seg.get("class_id")
            if old_id in id_map:
                seg["class_id"] = id_map[old_id]
        self.next_class_id = len(ordered_ids)
        self.update_all_lists()
        self.viewer.setFocus()

    def handle_class_id_change(self, item):
        """Apply a class id typed into column 1; revert non-integer input."""
        if item.column() != 1:
            return
        table = self.right_panel.segment_table
        table.blockSignals(True)
        try:
            new_class_id = int(item.text())
            original_index = table.item(item.row(), 0).data(Qt.ItemDataRole.UserRole)
            self.segments[original_index]["class_id"] = new_class_id
            if new_class_id >= self.next_class_id:
                self.next_class_id = new_class_id + 1
            self.update_all_lists()
        except (ValueError, TypeError):
            original_index = table.item(item.row(), 0).data(Qt.ItemDataRole.UserRole)
            item.setText(str(self.segments[original_index]["class_id"]))
        table.blockSignals(False)
        self.viewer.setFocus()

    def get_selected_segment_indices(self):
        """Return the segment indices of the selected table rows, in row order."""
        table = self.right_panel.segment_table
        selected_rows = sorted({item.row() for item in table.selectedItems()})
        return [
            table.item(row, 0).data(Qt.ItemDataRole.UserRole)
            for row in selected_rows
            if table.item(row, 0)
        ]

    def _flash_status(self, text, msec=3000):
        """Show ``text`` in the status label and clear it after ``msec`` ms."""
        label = self.right_panel.status_label
        label.setText(text)
        QTimer.singleShot(msec, label.clear)

    def save_output_to_npz(self):
        """Save every segment as an (H, W, C) uint8 class-mask stack.

        Channel ``k`` is the union of all segments whose class id is the
        k-th smallest distinct id. The array is written under the key
        ``"mask"`` to ``<image path without extension>.npz``. A write
        failure is reported in the status label instead of propagating out
        of the Qt event handler.
        """
        if not self.segments or not self.current_image_path:
            return
        self.right_panel.status_label.setText("Saving...")
        QApplication.processEvents()  # let "Saving..." paint before the work

        unique_class_ids = self._unique_class_ids()
        if not unique_class_ids:
            self._flash_status("Save failed: No classes.")
            return

        output_path = os.path.splitext(self.current_image_path)[0] + ".npz"
        pixmap = self.viewer._pixmap_item.pixmap()
        h, w = pixmap.height(), pixmap.width()
        channel_of = {cid: ch for ch, cid in enumerate(unique_class_ids)}
        final_mask_tensor = np.zeros((h, w, len(unique_class_ids)), dtype=np.uint8)

        for seg in self.segments:
            channel = channel_of.get(seg.get("class_id"))
            if channel is None:
                continue
            mask = self._segment_mask(seg)
            if mask is not None:
                final_mask_tensor[:, :, channel] = np.logical_or(
                    final_mask_tensor[:, :, channel], mask
                )

        try:
            np.savez_compressed(output_path, mask=final_mask_tensor)
        except OSError as err:
            self._flash_status(f"Save failed: {err}")
            return
        # Re-setting the root path, presumably so the file tree notices the
        # new .npz file.
        self.file_model.setRootPath(self.file_model.rootPath())
        self._flash_status("Saved!")

    def save_current_segment(self):
        """Commit the current SAM preview as a new segment with a fresh class."""
        if self.mode != "sam_points" or self.preview_mask_item is None:
            return
        # Reuse the mask already predicted for the preview instead of running
        # the SAM predictor again on the same prompt points.
        mask = self.preview_mask
        if mask is not None:
            self.segments.append(
                {
                    "mask": mask,
                    "type": "SAM",
                    "vertices": None,
                    "class_id": self.next_class_id,
                }
            )
            self.next_class_id += 1
            self.clear_all_points()
            self.update_all_lists()

    def delete_selected_segments(self):
        """Delete every selected segment."""
        selected_indices = self.get_selected_segment_indices()
        if not selected_indices:
            return
        # Clear the selection first. update_segment_table re-selects rows by
        # segment index, and those indices shift once segments are deleted,
        # which would select unrelated surviving segments.
        self.right_panel.segment_table.clearSelection()
        for i in sorted(selected_indices, reverse=True):
            del self.segments[i]
        self.update_all_lists()
        self.viewer.setFocus()

    def load_existing_mask(self):
        """Load ``<image stem>.npz`` back in as one "Loaded" segment per class.

        A 2-D ``mask`` array is treated as a single channel. Every non-empty
        channel ``i`` becomes a segment with class id ``i``.
        """
        if not self.current_image_path:
            return
        npz_path = os.path.splitext(self.current_image_path)[0] + ".npz"
        if not os.path.exists(npz_path):
            return
        with np.load(npz_path) as data:
            if "mask" not in data:
                return
            mask_data = data["mask"]
        if mask_data.ndim == 2:
            mask_data = np.expand_dims(mask_data, axis=-1)
        num_classes = mask_data.shape[2]
        for i in range(num_classes):
            class_mask = mask_data[:, :, i].astype(bool)
            if np.any(class_mask):
                self.segments.append(
                    {
                        "mask": class_mask,
                        "type": "Loaded",
                        "vertices": None,
                        "class_id": i,
                    }
                )
        self.next_class_id = num_classes
        self.update_all_lists()

    def add_point(self, pos, positive):
        """Record a SAM prompt point and draw it (green positive, red negative)."""
        point_list = self.positive_points if positive else self.negative_points
        point_list.append([int(pos.x()), int(pos.y())])
        color = Qt.GlobalColor.green if positive else Qt.GlobalColor.red
        point_item = QGraphicsEllipseItem(pos.x() - 4, pos.y() - 4, 8, 8)
        point_item.setBrush(QBrush(color))
        point_item.setPen(QPen(Qt.GlobalColor.white))
        # Remember the polarity so undo_last_action pops the right list.
        point_item.setData(self._POINT_POLARITY_KEY, bool(positive))
        self.viewer.scene().addItem(point_item)
        self.point_items.append(point_item)

    def update_segmentation(self):
        """Re-run SAM on the prompt points and refresh the yellow preview."""
        if self.preview_mask_item is not None:
            self.viewer.scene().removeItem(self.preview_mask_item)
        # Drop references to the removed overlay. Otherwise the
        # save/clear paths would act on a stale item.
        self.preview_mask_item = None
        self.preview_mask = None
        if not self.positive_points:
            return
        mask = self.sam_model.predict(self.positive_points, self.negative_points)
        if mask is None:
            return
        self.preview_mask = mask
        self.preview_mask_item = self.viewer.scene().addPixmap(
            mask_to_pixmap(mask, (255, 255, 0))
        )
        self.preview_mask_item.setZValue(50)

    def clear_all_points(self):
        """Discard SAM prompts, the polygon in progress and the SAM preview."""
        scene = self.viewer.scene()
        if self.rubber_band_line:
            scene.removeItem(self.rubber_band_line)
            self.rubber_band_line = None
        self.positive_points.clear()
        self.negative_points.clear()
        for item in self.point_items:
            scene.removeItem(item)
        self.point_items.clear()
        self.polygon_points.clear()
        for item in self.polygon_preview_items:
            scene.removeItem(item)
        self.polygon_preview_items.clear()
        if self.preview_mask_item is not None:
            scene.removeItem(self.preview_mask_item)
            self.preview_mask_item = None
        self.preview_mask = None

    def handle_polygon_click(self, pos):
        """Add a polygon vertex, or close the polygon near its first vertex.

        A click within 5 px of the first vertex finalizes the polygon when it
        already has at least 3 vertices. Otherwise that click is ignored.
        """
        if self.polygon_points:
            first = self.polygon_points[0]
            dx, dy = pos.x() - first.x(), pos.y() - first.y()
            if dx * dx + dy * dy < 25:
                if len(self.polygon_points) > 2:
                    self.finalize_polygon()
                return
        self.polygon_points.append(pos)
        dot = QGraphicsEllipseItem(pos.x() - 2, pos.y() - 2, 4, 4)
        dot.setBrush(QBrush(Qt.GlobalColor.blue))
        dot.setPen(QPen(Qt.GlobalColor.cyan))
        self.viewer.scene().addItem(dot)
        self.polygon_preview_items.append(dot)
        self.draw_polygon_preview()

    def draw_polygon_preview(self):
        """Redraw the in-progress polygon's fill and edges, keeping vertex dots."""
        scene = self.viewer.scene()
        if self.rubber_band_line:
            scene.removeItem(self.rubber_band_line)
            self.rubber_band_line = None
        kept = []
        for item in self.polygon_preview_items:
            if isinstance(item, QGraphicsEllipseItem):
                kept.append(item)
            else:
                scene.removeItem(item)
        self.polygon_preview_items = kept

        if len(self.polygon_points) > 2:
            preview_poly = QGraphicsPolygonItem(QPolygonF(self.polygon_points))
            preview_poly.setBrush(QBrush(QColor(0, 255, 255, 100)))
            preview_poly.setPen(QPen(Qt.GlobalColor.transparent))
            scene.addItem(preview_poly)
            self.polygon_preview_items.append(preview_poly)

        for start, end in zip(self.polygon_points, self.polygon_points[1:]):
            line = QGraphicsLineItem(start.x(), start.y(), end.x(), end.y())
            line.setPen(QPen(Qt.GlobalColor.cyan, 2))
            scene.addItem(line)
            self.polygon_preview_items.append(line)
|
852
|
+
|
853
|
+
|
854
|
+
def main():
    """Start the LazyLabel GUI and block until the window is closed."""
    application = QApplication(sys.argv)
    qdarktheme.setup_theme()
    # Constructing the model performs its one-time check/download.
    model = SamModel(model_type="vit_h")
    window = MainWindow(model)
    window.show()
    exit_code = application.exec()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
|