coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
- coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
- coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
- coralnet_toolbox/CoralNet/QtDownload.py +2 -1
- coralnet_toolbox/Explorer/QtDataItem.py +1 -1
- coralnet_toolbox/Explorer/QtExplorer.py +159 -17
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
- coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
- coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
- coralnet_toolbox/IO/QtOpenProject.py +46 -78
- coralnet_toolbox/IO/QtSaveProject.py +18 -43
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
- coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
- coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
- coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
- coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
- coralnet_toolbox/QtAnnotationWindow.py +42 -14
- coralnet_toolbox/QtEventFilter.py +19 -2
- coralnet_toolbox/QtImageWindow.py +134 -86
- coralnet_toolbox/QtLabelWindow.py +14 -2
- coralnet_toolbox/QtMainWindow.py +122 -9
- coralnet_toolbox/QtProgressBar.py +52 -27
- coralnet_toolbox/Rasters/QtRaster.py +59 -7
- coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
- coralnet_toolbox/SAM/QtBatchInference.py +0 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
- coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
- coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
- coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
- coralnet_toolbox/SeeAnything/__init__.py +2 -0
- coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
- coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
- coralnet_toolbox/Tools/QtSAMTool.py +222 -57
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
- coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
- coralnet_toolbox/Tools/QtSelectTool.py +27 -3
- coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
- coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
- coralnet_toolbox/Tools/__init__.py +2 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +137 -47
- coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
- coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -4,8 +4,8 @@ import warnings
|
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
|
7
|
-
from PyQt5.QtCore import Qt, pyqtSignal, QObject, QPointF, QTimer, pyqtProperty
|
8
7
|
from PyQt5.QtGui import QColor, QPolygonF, QPen, QBrush
|
8
|
+
from PyQt5.QtCore import Qt, pyqtSignal, QObject, QPointF, QTimer, pyqtProperty
|
9
9
|
from PyQt5.QtWidgets import (QMessageBox, QGraphicsEllipseItem, QGraphicsRectItem,
|
10
10
|
QGraphicsPolygonItem, QGraphicsScene, QGraphicsItemGroup)
|
11
11
|
|
@@ -54,12 +54,11 @@ class Annotation(QObject):
|
|
54
54
|
self.annotation_size = None
|
55
55
|
self.tolerance = 0.1 # Default detail level for simplification/densification
|
56
56
|
|
57
|
-
# Attributes to store the graphics items for center/centroid
|
57
|
+
# Attributes to store the graphics items for center/centroid and bounding box
|
58
58
|
self.center_graphics_item = None
|
59
59
|
self.bounding_box_graphics_item = None
|
60
|
-
self.polygon_graphics_item = None
|
61
60
|
|
62
|
-
#
|
61
|
+
# Group for all graphics items
|
63
62
|
self.graphics_item_group = None
|
64
63
|
|
65
64
|
# Animation properties
|
@@ -87,6 +86,10 @@ class Annotation(QObject):
|
|
87
86
|
def get_polygon(self):
|
88
87
|
"""Get the polygon representation of this annotation."""
|
89
88
|
raise NotImplementedError("Subclasses must implement this method.")
|
89
|
+
|
90
|
+
def get_painter_path(self):
|
91
|
+
"""Get the QPainterPath representation of this annotation."""
|
92
|
+
raise NotImplementedError("Subclasses must implement this method.")
|
90
93
|
|
91
94
|
def get_bounding_box_top_left(self):
|
92
95
|
"""Get the top-left corner of the annotation's bounding box."""
|
@@ -125,6 +128,11 @@ class Annotation(QObject):
|
|
125
128
|
def cut(cls, annotations: list, cutting_points: list):
|
126
129
|
"""Cut multiple annotations using specified cutting points."""
|
127
130
|
raise NotImplementedError("Subclasses must implement this method.")
|
131
|
+
|
132
|
+
@classmethod
|
133
|
+
def subtract(cls, base_annotation, cutter_annotations: list):
|
134
|
+
"""Subtract cutter annotations from a base annotation."""
|
135
|
+
raise NotImplementedError("Subclasses must implement this method.")
|
128
136
|
|
129
137
|
def show_warning_message(self):
|
130
138
|
"""Display a warning message about removing machine suggestions when altering an annotation."""
|
@@ -171,45 +179,35 @@ class Annotation(QObject):
|
|
171
179
|
self.graphics_item = None
|
172
180
|
self.center_graphics_item = None
|
173
181
|
self.bounding_box_graphics_item = None
|
174
|
-
self.polygon_graphics_item = None
|
175
182
|
|
176
183
|
def create_graphics_item(self, scene: QGraphicsScene):
|
177
184
|
"""Create all graphics items for the annotation and add them to the scene as a group."""
|
178
185
|
# Remove old group if it exists
|
179
186
|
if self.graphics_item_group and self.graphics_item_group.scene():
|
180
187
|
self.graphics_item_group.scene().removeItem(self.graphics_item_group)
|
181
|
-
# Clear references to deleted items
|
182
188
|
self.center_graphics_item = None
|
183
189
|
self.bounding_box_graphics_item = None
|
184
|
-
self.polygon_graphics_item = None
|
185
190
|
self.graphics_item_group = QGraphicsItemGroup()
|
186
191
|
|
187
|
-
#
|
188
|
-
|
189
|
-
self.graphics_item
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
self.graphics_item.setData(0, self.id)
|
200
|
-
self.graphics_item_group.addToGroup(self.graphics_item)
|
192
|
+
# The subclass has already created self.graphics_item.
|
193
|
+
# This parent method is now only responsible for styling and grouping it.
|
194
|
+
if self.graphics_item:
|
195
|
+
color = QColor(self.label.color)
|
196
|
+
color.setAlpha(self.transparency)
|
197
|
+
self.graphics_item.setBrush(QBrush(color))
|
198
|
+
|
199
|
+
self.graphics_item.setPen(self._create_pen(color))
|
200
|
+
|
201
|
+
self.graphics_item.setData(0, self.id)
|
202
|
+
self.graphics_item_group.addToGroup(self.graphics_item)
|
201
203
|
|
202
|
-
# Create the
|
204
|
+
# Create and group the helper graphics (center, bbox, etc.)
|
203
205
|
self.create_center_graphics_item(self.center_xy, scene, add_to_group=True)
|
204
|
-
# Create the bounding box graphics item
|
205
206
|
self.create_bounding_box_graphics_item(self.get_bounding_box_top_left(),
|
206
207
|
self.get_bounding_box_bottom_right(),
|
207
208
|
scene, add_to_group=True)
|
208
|
-
|
209
|
-
|
210
|
-
self.create_polygon_graphics_item(points, scene, add_to_group=True)
|
211
|
-
|
212
|
-
# Add the group to the scene
|
209
|
+
|
210
|
+
# Add the final group to the scene
|
213
211
|
scene.addItem(self.graphics_item_group)
|
214
212
|
|
215
213
|
def set_visibility(self, visible):
|
@@ -289,8 +287,6 @@ class Annotation(QObject):
|
|
289
287
|
self.center_graphics_item.setPen(pen)
|
290
288
|
if self.bounding_box_graphics_item:
|
291
289
|
self.bounding_box_graphics_item.setPen(pen)
|
292
|
-
if self.polygon_graphics_item:
|
293
|
-
self.polygon_graphics_item.setPen(pen)
|
294
290
|
|
295
291
|
def create_center_graphics_item(self, center_xy, scene, add_to_group=False):
|
296
292
|
"""Create a graphical item representing the annotation's center point."""
|
@@ -345,32 +341,6 @@ class Annotation(QObject):
|
|
345
341
|
self.graphics_item_group.addToGroup(self.bounding_box_graphics_item)
|
346
342
|
else:
|
347
343
|
scene.addItem(self.bounding_box_graphics_item)
|
348
|
-
|
349
|
-
def create_polygon_graphics_item(self, points, scene, add_to_group=False):
|
350
|
-
"""Create a graphical item representing the annotation's polygon outline."""
|
351
|
-
try:
|
352
|
-
has_scene = self.polygon_graphics_item and self.polygon_graphics_item.scene()
|
353
|
-
except RuntimeError:
|
354
|
-
self.polygon_graphics_item = None
|
355
|
-
has_scene = False
|
356
|
-
|
357
|
-
if has_scene:
|
358
|
-
self.polygon_graphics_item.scene().removeItem(self.polygon_graphics_item)
|
359
|
-
|
360
|
-
color = QColor(self.label.color)
|
361
|
-
color.setAlpha(self.transparency)
|
362
|
-
|
363
|
-
polygon = QPolygonF(points)
|
364
|
-
self.polygon_graphics_item = QGraphicsPolygonItem(polygon)
|
365
|
-
self.polygon_graphics_item.setBrush(color)
|
366
|
-
|
367
|
-
# Use the consolidated pen creation method
|
368
|
-
self.polygon_graphics_item.setPen(self._create_pen(color))
|
369
|
-
|
370
|
-
if add_to_group and self.graphics_item_group:
|
371
|
-
self.graphics_item_group.addToGroup(self.polygon_graphics_item)
|
372
|
-
else:
|
373
|
-
scene.addItem(self.polygon_graphics_item)
|
374
344
|
|
375
345
|
def get_center_xy(self):
|
376
346
|
"""Get the center coordinates of the annotation."""
|
@@ -405,9 +375,11 @@ class Annotation(QObject):
|
|
405
375
|
try:
|
406
376
|
if self.graphics_item_group and self.graphics_item_group.scene():
|
407
377
|
scene = self.graphics_item_group.scene()
|
378
|
+
|
408
379
|
except RuntimeError:
|
409
380
|
self.graphics_item_group = None
|
410
381
|
scene = None
|
382
|
+
|
411
383
|
if scene is not None:
|
412
384
|
# Remove the old group from the scene
|
413
385
|
scene.removeItem(self.graphics_item_group)
|
@@ -440,19 +412,6 @@ class Annotation(QObject):
|
|
440
412
|
# Use the consolidated pen creation method
|
441
413
|
self.bounding_box_graphics_item.setPen(self._create_pen(color))
|
442
414
|
|
443
|
-
def update_polygon_graphics_item(self, points):
|
444
|
-
"""Update the shape and appearance of the polygon graphics item."""
|
445
|
-
if self.polygon_graphics_item:
|
446
|
-
color = QColor(self.label.color)
|
447
|
-
color.setAlpha(self.transparency)
|
448
|
-
|
449
|
-
polygon = QPolygonF(points)
|
450
|
-
self.polygon_graphics_item.setPolygon(polygon)
|
451
|
-
self.polygon_graphics_item.setBrush(color)
|
452
|
-
|
453
|
-
# Use the consolidated pen creation method
|
454
|
-
self.polygon_graphics_item.setPen(self._create_pen(color))
|
455
|
-
|
456
415
|
def update_transparency(self, transparency: int):
|
457
416
|
"""Update the transparency value of the annotation's graphical representation."""
|
458
417
|
if self.transparency != transparency:
|
@@ -0,0 +1,408 @@
|
|
1
|
+
import warnings
|
2
|
+
|
3
|
+
import zlib
|
4
|
+
import base64
|
5
|
+
import rasterio
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from scipy.ndimage import label as ndimage_label
|
10
|
+
from skimage.measure import find_contours
|
11
|
+
|
12
|
+
from PyQt5.QtCore import Qt, QPointF, QRectF, QPolygonF
|
13
|
+
from PyQt5.QtWidgets import QGraphicsScene, QGraphicsPixmapItem
|
14
|
+
from PyQt5.QtGui import QPixmap, QColor, QImage, QPainter, QBrush
|
15
|
+
|
16
|
+
from coralnet_toolbox.Annotations.QtAnnotation import Annotation
|
17
|
+
from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
|
18
|
+
|
19
|
+
# Suppress DeprecationWarning output (e.g. from Qt/numpy APIs used below).
# NOTE(review): filterwarnings() mutates the global warnings filter list, so this silences
# DeprecationWarning for the whole process as a side effect of importing this module, not just
# for this file — consider scoping it with warnings.catch_warnings() where needed; confirm intent.
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
20
|
+
|
21
|
+
|
22
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
23
|
+
# Helper Functions for Serialization
|
24
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
25
|
+
|
26
|
+
|
27
|
+
def rle_encode(mask):
    """
    Encode an array of integer class IDs using Run-Length Encoding.

    The array is flattened in row-major (C) order and written as comma-separated
    ``value,run_length`` pairs. That text is zlib-compressed and base64-encoded so
    it can be stored in JSON. The array shape is NOT stored; callers must save it
    separately (see ``rle_decode``).

    Args:
        mask: Array-like of integer or boolean values, of any shape.

    Returns:
        str: ASCII base64 string of the compressed RLE text. An empty mask encodes
        to the compressed empty string.
    """
    pixels = np.asarray(mask).ravel()

    if pixels.size == 0:
        rle_pairs = ""
    else:
        # Positions where the value changes mark the start of every run after the first.
        # Comparing neighbours directly (rather than appending a sentinel such as -1)
        # keeps the last run intact even when the mask itself contains the sentinel value.
        boundaries = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1
        starts = np.concatenate(([0], boundaries))
        ends = np.concatenate((boundaries, [pixels.size]))
        runs = ends - starts
        values = pixels[starts]

        # int() normalises numpy scalars (including bools -> 0/1) to plain integer text,
        # which is what rle_decode parses.
        rle_pairs = ",".join(f"{int(v)},{int(r)}" for v, r in zip(values, runs))

    # Further compress with zlib and encode in base64 for JSON compatibility
    compressed = zlib.compress(rle_pairs.encode('utf-8'))
    return base64.b64encode(compressed).decode('ascii')
|
44
|
+
|
45
|
+
|
46
|
+
def rle_decode(rle_string, shape, dtype=None):
    """
    Decode a string produced by ``rle_encode`` back into an array.

    Args:
        rle_string (str): base64 text of zlib-compressed ``value,run`` pairs.
        shape: Target shape (tuple or list). Its total size must equal the sum of
            the run lengths.
        dtype: Optional NumPy dtype for the result. When None (the default), NumPy's
            default integer type is used, as in earlier versions.

    Returns:
        np.ndarray: The decoded array reshaped to ``shape``.

    Raises:
        ValueError: If the RLE text has an unpaired value/run token, a run length is
            negative, or the decoded length does not match ``shape``.
    """
    # Decode from base64 and decompress with zlib
    decompressed = zlib.decompress(base64.b64decode(rle_string)).decode('utf-8')

    # An empty mask is encoded as an empty string; "".split(',') would yield [''],
    # and int('') would fail.
    tokens = decompressed.split(',') if decompressed else []

    # Without this check np.repeat would silently broadcast a single run length
    # over every value and produce wrong pixels.
    if len(tokens) % 2:
        raise ValueError("Malformed RLE data: expected an even number of value/run tokens.")

    value_dtype = int if dtype is None else dtype
    values = np.asarray([int(v) for v in tokens[0::2]], dtype=value_dtype)
    runs = np.asarray([int(r) for r in tokens[1::2]], dtype=np.intp)

    # Reconstruct the flat pixel array, then restore the original 2D layout
    pixels = np.repeat(values, runs)
    return pixels.reshape(shape)
|
64
|
+
|
65
|
+
|
66
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
67
|
+
# Classes
|
68
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
69
|
+
|
70
|
+
|
71
|
+
class MaskAnnotation(Annotation):
    """
    Full-image semantic segmentation annotation backed by a 2D array of class IDs.

    Each pixel of ``mask_data`` holds an integer class ID, and ``label_map`` maps IDs to
    Label objects. Class ID 0 is treated as background by ``contains_point``,
    ``get_area`` and ``get_class_statistics``. The mask is displayed as a single RGBA
    pixmap overlay positioned at the image origin.
    """

    def __init__(self,
                 image_path: str,
                 mask_data: np.ndarray,
                 label_map: dict,
                 transparency: int = 128,
                 show_msg: bool = False,
                 rasterio_src=None):
        """
        Initialize a full-image semantic segmentation annotation.
        There should only be one MaskAnnotation per image.

        Args:
            image_path (str): Path to the source image.
            mask_data (np.ndarray): 2D numpy array of integer class IDs, matching image dimensions.
                The array is kept by reference and edited in place by the editing methods.
            label_map (dict): A map of {class_id: Label_object}.
            transparency (int): Alpha value (written into an 8-bit channel) for the mask overlay.
            show_msg (bool): Forwarded to the base Annotation constructor.
            rasterio_src: Optional rasterio dataset object for the source image. Its CRS and
                transform are cached here so export_as_raster still works after it is closed.

        Raises:
            ValueError: If ``label_map`` is empty or ``mask_data`` is not two-dimensional.
        """
        # For a full-image mask, the concept of a single "primary label" is ambiguous.
        # We'll use the first available label as a placeholder to satisfy the base class.
        if not label_map:
            raise ValueError("label_map cannot be empty.")
        if np.ndim(mask_data) != 2:
            raise ValueError("mask_data must be a 2D array of class IDs.")
        placeholder_label = next(iter(label_map.values()))

        super().__init__(
            short_label_code=placeholder_label.short_label_code,
            long_label_code=placeholder_label.long_label_code,
            color=placeholder_label.color,
            image_path=image_path,
            label_id=placeholder_label.id,
            transparency=transparency,
            show_msg=show_msg
        )

        self.mask_data = mask_data
        self.label_map = label_map
        self.offset = QPointF(0, 0)  # A full-image mask has no offset
        self.rasterio_src = rasterio_src  # Store rasterio source if provided
        # Cache georeferencing now: the dataset may be closed later (e.g. by a `with` block).
        self._crs, self._transform = self._read_georeference(rasterio_src)

        # Set geometric properties from mask
        self.set_centroid()
        self.set_cropped_bbox()

    # --- Internal helpers ---

    @staticmethod
    def _read_georeference(src):
        """Return ``(crs, transform)`` from a rasterio dataset, or ``(None, None)`` if unavailable."""
        if src is None:
            return None, None
        try:
            return src.crs, src.transform
        except (rasterio.errors.RasterioError, OSError, ValueError):
            # NOTE(review): assumes reading a closed dataset raises one of these — confirm
            # for the rasterio version in use.
            return None, None

    @staticmethod
    def _to_pixel(point: QPointF):
        """
        Map a scene point to integer ``(column, row)`` pixel indices.

        Uses floor rather than int() truncation, so coordinates in (-1, 0) land outside
        the image instead of on pixel 0.
        """
        return int(np.floor(point.x())), int(np.floor(point.y()))

    # --- Geometry ---

    def set_centroid(self):
        """Set the centroid to the center of the image."""
        height, width = self.mask_data.shape
        self.center_xy = QPointF(width / 2.0, height / 2.0)

    def set_cropped_bbox(self):
        """Set the bounding box to the full dimensions of the image."""
        height, width = self.mask_data.shape
        self.cropped_bbox = (0, 0, width, height)
        self.annotation_size = int(max(width, height))

    def _render_mask_to_pixmap(self) -> QPixmap:
        """
        Convert ``mask_data`` into a colored RGBA QPixmap for display.

        Class IDs present in ``label_map`` get their label color with ``self.transparency``
        as alpha. Any other ID renders fully transparent.
        """
        height, width = self.mask_data.shape

        # Size the lookup table to cover every mapped ID AND every ID occurring in the mask.
        # Unmapped IDs (e.g. labels missing after loading a project) then hit a transparent
        # row instead of raising IndexError.
        highest_label_id = max(self.label_map) if self.label_map else 0
        highest_mask_id = int(self.mask_data.max()) if self.mask_data.size else 0
        color_map = np.zeros((max(highest_label_id, highest_mask_id) + 1, 4), dtype=np.uint8)
        for class_id, label in self.label_map.items():
            color = label.color
            color_map[class_id] = [color.red(), color.green(), color.blue(), self.transparency]

        # NOTE(review): assumes class IDs are non-negative; a negative ID would index from the end.
        colored_mask = np.ascontiguousarray(color_map[self.mask_data])
        # Pass bytesPerLine explicitly so Qt does not have to infer the numpy row stride.
        q_image = QImage(colored_mask.data, width, height, 4 * width, QImage.Format_RGBA8888)

        return QPixmap.fromImage(q_image)

    def contains_point(self, point: QPointF) -> bool:
        """Return True if ``point`` lies on a pixel with a non-background (non-zero) class ID."""
        x, y = self._to_pixel(point)
        height, width = self.mask_data.shape
        if 0 <= y < height and 0 <= x < width:
            return bool(self.mask_data[y, x] > 0)
        return False

    def get_area(self):
        """Return the total number of non-background (non-zero) pixels as a Python int."""
        return int(np.count_nonzero(self.mask_data))

    def get_bounding_box_top_left(self):
        """Get the top-left corner of the annotation's bounding box (always 0,0)."""
        return QPointF(0, 0)

    def get_bounding_box_bottom_right(self):
        """Get the bottom-right corner of the annotation's bounding box (the image size)."""
        height, width = self.mask_data.shape
        return QPointF(width, height)

    # --- Graphics ---

    def create_graphics_item(self, scene: QGraphicsScene):
        """
        Create the pixmap overlay plus center/bounding-box helpers and add them to ``scene`` as one group.

        The base ``Annotation.create_graphics_item`` styles ``self.graphics_item`` with
        ``setBrush``/``setPen``. ``QGraphicsPixmapItem`` has neither method, so the same
        grouping steps are done here without that styling instead of delegating to the base.
        """
        from PyQt5.QtWidgets import QGraphicsItemGroup

        # Remove any previous group so repeated calls do not stack overlays.
        old_group = self.graphics_item_group
        try:
            old_scene = old_group.scene() if old_group is not None else None
        except RuntimeError:  # The underlying C++ item was already deleted.
            old_scene = None
        if old_scene is not None:
            old_scene.removeItem(old_group)

        self.center_graphics_item = None
        self.bounding_box_graphics_item = None
        self.graphics_item_group = QGraphicsItemGroup()

        self.graphics_item = QGraphicsPixmapItem(self._render_mask_to_pixmap())
        self.graphics_item.setPos(self.offset)
        self.graphics_item.setData(0, self.id)
        self.graphics_item_group.addToGroup(self.graphics_item)

        # Create and group the helper graphics (center, bbox)
        self.create_center_graphics_item(self.center_xy, scene, add_to_group=True)
        self.create_bounding_box_graphics_item(self.get_bounding_box_top_left(),
                                               self.get_bounding_box_bottom_right(),
                                               scene, add_to_group=True)

        scene.addItem(self.graphics_item_group)

    def update_graphics_item(self):
        """Re-render the overlay pixmap from ``mask_data``, then defer to the base update."""
        if self.graphics_item:
            self.graphics_item.setPixmap(self._render_mask_to_pixmap())
        # NOTE(review): only part of the base update_graphics_item is visible. If it re-styles
        # graphics_item with setBrush/setPen, it will fail for a QGraphicsPixmapItem — confirm.
        super().update_graphics_item()

    def update_mask(self, brush_location: QPointF, brush_mask: np.ndarray, new_class_id: int):
        """
        Paint ``new_class_id`` into the mask where ``brush_mask`` is truthy.

        Args:
            brush_location (QPointF): Scene position of the brush's top-left corner.
            brush_mask (np.ndarray): 2D brush footprint; non-zero/True cells are painted.
                Parts that fall outside the image are clipped.
            new_class_id (int): Class ID to write.
        """
        x_start, y_start = self._to_pixel(brush_location)
        # Cast so an integer 0/1 brush acts as a boolean mask rather than as fancy indices.
        brush = np.asarray(brush_mask, dtype=bool)
        brush_h, brush_w = brush.shape
        mask_h, mask_w = self.mask_data.shape

        x_end = min(x_start + brush_w, mask_w)
        y_end = min(y_start + brush_h, mask_h)
        clipped_x_start = max(x_start, 0)
        clipped_y_start = max(y_start, 0)

        if clipped_x_start >= x_end or clipped_y_start >= y_end:
            return  # Brush lies entirely outside the image.

        # Basic slicing returns a view, so assigning into it edits mask_data in place.
        target_slice = self.mask_data[clipped_y_start:y_end, clipped_x_start:x_end]
        brush_x_offset = clipped_x_start - x_start
        brush_y_offset = clipped_y_start - y_start
        clipped_brush_mask = brush[brush_y_offset:brush_y_offset + target_slice.shape[0],
                                   brush_x_offset:brush_x_offset + target_slice.shape[1]]

        target_slice[clipped_brush_mask] = new_class_id

        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    # --- Data Manipulation & Editing Methods ---

    def fill_region(self, point: QPointF, new_class_id: int):
        """Fill the connected region under ``point`` with ``new_class_id`` (paint-bucket tool)."""
        x, y = self._to_pixel(point)
        height, width = self.mask_data.shape
        if not (0 <= y < height and 0 <= x < width):
            return

        old_class_id = self.mask_data[y, x]
        if old_class_id == new_class_id:
            return

        # Label connected components of the clicked class (scipy's default connectivity),
        # then repaint only the component that contains the clicked pixel.
        labeled_array, _ = ndimage_label(self.mask_data == old_class_id)
        region_label = labeled_array[y, x]
        self.mask_data[labeled_array == region_label] = new_class_id

        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    def replace_class(self, old_class_id: int, new_class_id: int):
        """Replace all pixels of one class ID with another across the entire mask."""
        self.mask_data[self.mask_data == old_class_id] = new_class_id
        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    def import_annotations(self, annotations: list, class_id: int):
        """
        Burn a list of vector annotations into the mask with a given class ID.

        Each annotation's ``get_painter_path()`` is rasterised with QPainter in scene
        coordinates, which coincide with pixel coordinates because the mask has no offset.
        """
        height, width = self.mask_data.shape
        # Use Format_Alpha8 for an efficient 8-bit stencil mask
        stencil = QImage(width, height, QImage.Format_Alpha8)
        stencil.fill(Qt.transparent)

        painter = QPainter(stencil)
        painter.setPen(Qt.NoPen)
        painter.setBrush(QBrush(QColor("white")))  # Opaque white
        for anno in annotations:
            painter.drawPath(anno.get_painter_path())
        painter.end()

        # Convert the QImage stencil to a boolean numpy array. QImage scanlines are padded
        # to a 4-byte boundary, so reshape by bytesPerLine() and crop the padding columns;
        # reshaping by `width` fails whenever width is not a multiple of 4.
        ptr = stencil.bits()
        ptr.setsize(stencil.byteCount())
        rows = np.array(ptr).reshape(height, stencil.bytesPerLine())
        stencil_np = rows[:, :width] > 0  # True where painted

        # Update the mask data where the stencil is True
        self.mask_data[stencil_np] = class_id

        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    # --- Analysis & Information Retrieval Methods ---

    def get_class_statistics(self) -> dict:
        """
        Return pixel counts and percentages for each labeled, non-background class.

        Percentages are relative to the number of non-background pixels. IDs that have
        no entry in ``label_map`` are omitted.
        """
        stats = {}
        total_pixels = self.get_area()
        if total_pixels == 0:
            return stats

        class_ids, counts = np.unique(self.mask_data[self.mask_data > 0], return_counts=True)

        for cid, count in zip(class_ids, counts):
            label = self.label_map.get(int(cid))
            if label:
                stats[label.short_label_code] = {
                    "pixel_count": int(count),
                    "percentage": float(count / total_pixels * 100),
                }
        return stats

    def get_class_at_point(self, point: QPointF) -> int:
        """Return the class ID at ``point`` as a Python int, or 0 (background) outside the image."""
        x, y = self._to_pixel(point)
        height, width = self.mask_data.shape
        if 0 <= y < height and 0 <= x < width:
            return int(self.mask_data[y, x])
        return 0  # Return background class if outside bounds

    # --- Conversion & Exporting Methods ---

    def get_binary_mask(self, class_id: int) -> np.ndarray:
        """Return a boolean numpy array where True corresponds to the given class ID."""
        return self.mask_data == class_id

    def to_instance_polygons(self, class_id: int) -> list:
        """
        Convert the outlines of ``class_id`` into PolygonAnnotations.

        Raises:
            KeyError: If a polygon is produced and ``class_id`` has no entry in ``label_map``.
        """
        binary_mask = self.get_binary_mask(class_id)
        # Add padding to handle contours touching the border
        padded_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)

        # Level is 0.5 to find contours between 0 and 1 values
        contours = find_contours(padded_mask, level=0.5)
        # NOTE(review): find_contours also returns the boundaries of holes, and each one becomes
        # a separate filled polygon here — confirm whether holes should be skipped or subtracted.

        annotations = []
        for contour in contours:
            # find_contours works in array-index space, where pixel (r, c) is centred on (r, c).
            # In scene space that pixel spans [c, c+1) x [r, r+1) (see contains_point), so its
            # centre is (c + 0.5, r + 0.5). Remove the 1-pixel padding, apply that half-pixel
            # shift, and swap (row, col) to (x, y).
            points = [QPointF(col - 0.5, row - 0.5) for row, col in contour]
            if len(points) > 2:  # Must have at least 3 points for a valid polygon
                # Use the label associated with the class_id for the new annotation
                label = self.label_map[class_id]
                anno = PolygonAnnotation(
                    points=points,
                    short_label_code=label.short_label_code,
                    long_label_code=label.long_label_code,
                    color=label.color,
                    image_path=self.image_path,
                    label_id=label.id
                )
                annotations.append(anno)
        return annotations

    def export_as_png(self, path: str, use_label_colors: bool = True):
        """
        Save the mask to a PNG file.

        Args:
            path (str): Destination file path.
            use_label_colors (bool): If True, save the colored overlay rendering. Otherwise,
                save raw class IDs as 8-bit grayscale; this requires every ID to be in [0, 255]
                and issues a warning without saving if that does not hold.
        """
        if use_label_colors:
            # Use rendering logic to create a colored image
            pixmap = self._render_mask_to_pixmap()
            pixmap.toImage().save(path)
            return

        # Save the raw class IDs as a grayscale image
        mask = self.mask_data
        height, width = mask.shape
        if mask.size == 0:
            warnings.warn("Mask is empty; nothing to save as a grayscale PNG.")
            return
        if mask.max() > 255:
            warnings.warn("Mask contains class IDs > 255; cannot save as 8-bit grayscale PNG.")
            return
        if mask.min() < 0:
            warnings.warn("Mask contains negative class IDs; cannot save as 8-bit grayscale PNG.")
            return

        img_data = np.ascontiguousarray(mask, dtype=np.uint8)
        # Pass bytesPerLine explicitly: without it QImage assumes 4-byte-aligned rows,
        # which skews the image whenever width is not a multiple of 4.
        q_image = QImage(img_data.data, width, height, width, QImage.Format_Grayscale8)
        q_image.save(path)

    def export_as_raster(self, path: str):
        """
        Save the mask data as a single-band GeoTIFF using rasterio.

        The CRS and transform come from ``rasterio_src`` when it is still readable.
        Otherwise the values cached at construction time are used, if any.
        """
        data = self.mask_data
        if data.dtype == bool:
            data = data.astype(np.uint8)  # GDAL has no boolean band type
        # NOTE(review): int64 masks (e.g. decoded from RLE) need a GDAL build with Int64 support.

        profile = {
            'driver': 'GTiff',
            'height': data.shape[0],
            'width': data.shape[1],
            'count': 1,
            'dtype': data.dtype
        }

        crs, transform = self._read_georeference(self.rasterio_src)
        if crs is None and transform is None:
            crs = getattr(self, '_crs', None)
            transform = getattr(self, '_transform', None)
        if crs is not None:
            profile['crs'] = crs
        if transform is not None:
            profile['transform'] = transform

        with rasterio.open(path, 'w', **profile) as dst:
            dst.write(data, 1)

    # --- Serialization & Deserialization ---

    def to_dict(self):
        """Serialize the annotation to a dictionary, using RLE for the mask and recording its dtype."""
        base_dict = super().to_dict()
        rle_string = rle_encode(self.mask_data)
        base_dict.update({
            'shape': self.mask_data.shape,
            'dtype': str(self.mask_data.dtype),
            'rle_mask': rle_string,
            'label_map': {int(cid): label.short_label_code for cid, label in self.label_map.items()}
        })
        return base_dict

    @classmethod
    def from_dict(cls, data, label_window):
        """
        Instantiate a MaskAnnotation from a dictionary produced by ``to_dict``.

        Label short codes that ``label_window`` cannot resolve are skipped. If none resolve,
        the constructor raises ValueError. Dictionaries without a ``'dtype'`` key (older
        saves) keep the default integer dtype of the decoded array.
        """
        label_map = {}
        for cid_str, short_code in data['label_map'].items():
            label = label_window.get_label_by_short_code(short_code)
            if label:
                # JSON turns integer keys into strings, so convert them back.
                label_map[int(cid_str)] = label

        mask_data = rle_decode(data['rle_mask'], tuple(data['shape']))
        stored_dtype = data.get('dtype')
        if stored_dtype is not None:
            mask_data = mask_data.astype(stored_dtype)

        annotation = cls(
            image_path=data['image_path'],
            mask_data=mask_data,
            label_map=label_map
        )
        annotation.id = data.get('id', annotation.id)
        annotation.data = data.get('data', {})
        return annotation

    @classmethod
    def from_rasterio(cls, file_path: str, image_path: str, label_map: dict):
        """
        Create a MaskAnnotation by loading band 1 of a raster file.

        The instance is constructed while the dataset is still open, so its CRS and
        transform are cached for later export even though ``rasterio_src`` is closed
        once this method returns.
        """
        with rasterio.open(file_path) as src:
            mask_data = src.read(1)
            return cls(
                image_path=image_path,
                mask_data=mask_data,
                label_map=label_map,
                rasterio_src=src
            )

    # --- Compatibility Methods ---

    def get_perimeter(self):
        """Return the perimeter of the full image rectangle."""
        height, width = self.mask_data.shape
        return 2 * (width + height)

    def get_polygon(self):
        """Return the full image rectangle as a QPolygonF."""
        height, width = self.mask_data.shape
        return QPolygonF(QRectF(0, 0, width, height))

    def __repr__(self):
        return (f"MaskAnnotation(id={self.id}, image_path={self.image_path}, "
                f"shape={self.mask_data.shape})")
|