coralnet-toolbox 0.0.73__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (41):
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  8. coralnet_toolbox/Explorer/QtExplorer.py +16 -14
  9. coralnet_toolbox/Explorer/QtSettingsWidgets.py +114 -82
  10. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  11. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  12. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  13. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  14. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +1 -1
  15. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
  16. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  17. coralnet_toolbox/QtEventFilter.py +11 -0
  18. coralnet_toolbox/QtImageWindow.py +117 -68
  19. coralnet_toolbox/QtLabelWindow.py +13 -1
  20. coralnet_toolbox/QtMainWindow.py +5 -27
  21. coralnet_toolbox/QtProgressBar.py +52 -27
  22. coralnet_toolbox/Rasters/RasterTableModel.py +8 -8
  23. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  24. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +779 -161
  25. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +86 -149
  26. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  27. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  28. coralnet_toolbox/Tools/QtSAMTool.py +72 -50
  29. coralnet_toolbox/Tools/QtSeeAnythingTool.py +8 -5
  30. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  31. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  32. coralnet_toolbox/Tools/__init__.py +2 -0
  33. coralnet_toolbox/__init__.py +1 -1
  34. coralnet_toolbox/utilities.py +137 -47
  35. coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
  36. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +40 -38
  37. coralnet_toolbox-0.0.73.dist-info/METADATA +0 -341
  38. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
  39. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
  40. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
  41. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -4,8 +4,8 @@ import warnings
4
4
 
5
5
  import numpy as np
6
6
 
7
- from PyQt5.QtCore import Qt, pyqtSignal, QObject, QPointF, QTimer, pyqtProperty
8
7
  from PyQt5.QtGui import QColor, QPolygonF, QPen, QBrush
8
+ from PyQt5.QtCore import Qt, pyqtSignal, QObject, QPointF, QTimer, pyqtProperty
9
9
  from PyQt5.QtWidgets import (QMessageBox, QGraphicsEllipseItem, QGraphicsRectItem,
10
10
  QGraphicsPolygonItem, QGraphicsScene, QGraphicsItemGroup)
11
11
 
@@ -54,12 +54,11 @@ class Annotation(QObject):
54
54
  self.annotation_size = None
55
55
  self.tolerance = 0.1 # Default detail level for simplification/densification
56
56
 
57
- # Attributes to store the graphics items for center/centroid, bounding box, and polygon
57
+ # Attributes to store the graphics items for center/centroid and bounding box
58
58
  self.center_graphics_item = None
59
59
  self.bounding_box_graphics_item = None
60
- self.polygon_graphics_item = None
61
60
 
62
- # New: group for all graphics items
61
+ # Group for all graphics items
63
62
  self.graphics_item_group = None
64
63
 
65
64
  # Animation properties
@@ -87,6 +86,10 @@ class Annotation(QObject):
87
86
  def get_polygon(self):
88
87
  """Get the polygon representation of this annotation."""
89
88
  raise NotImplementedError("Subclasses must implement this method.")
89
+
90
+ def get_painter_path(self):
91
+ """Get the QPainterPath representation of this annotation."""
92
+ raise NotImplementedError("Subclasses must implement this method.")
90
93
 
91
94
  def get_bounding_box_top_left(self):
92
95
  """Get the top-left corner of the annotation's bounding box."""
@@ -125,6 +128,11 @@ class Annotation(QObject):
125
128
  def cut(cls, annotations: list, cutting_points: list):
126
129
  """Cut multiple annotations using specified cutting points."""
127
130
  raise NotImplementedError("Subclasses must implement this method.")
131
+
132
+ @classmethod
133
+ def subtract(cls, base_annotation, cutter_annotations: list):
134
+ """Subtract cutter annotations from a base annotation."""
135
+ raise NotImplementedError("Subclasses must implement this method.")
128
136
 
129
137
  def show_warning_message(self):
130
138
  """Display a warning message about removing machine suggestions when altering an annotation."""
@@ -171,45 +179,35 @@ class Annotation(QObject):
171
179
  self.graphics_item = None
172
180
  self.center_graphics_item = None
173
181
  self.bounding_box_graphics_item = None
174
- self.polygon_graphics_item = None
175
182
 
176
183
  def create_graphics_item(self, scene: QGraphicsScene):
177
184
  """Create all graphics items for the annotation and add them to the scene as a group."""
178
185
  # Remove old group if it exists
179
186
  if self.graphics_item_group and self.graphics_item_group.scene():
180
187
  self.graphics_item_group.scene().removeItem(self.graphics_item_group)
181
- # Clear references to deleted items
182
188
  self.center_graphics_item = None
183
189
  self.bounding_box_graphics_item = None
184
- self.polygon_graphics_item = None
185
190
  self.graphics_item_group = QGraphicsItemGroup()
186
191
 
187
- # Create the main graphics item based on the polygon
188
- polygon = self.get_polygon()
189
- self.graphics_item = QGraphicsPolygonItem(polygon)
190
-
191
- # Style the main graphics item with color and pen
192
- color = QColor(self.label.color)
193
- color.setAlpha(self.transparency)
194
- self.graphics_item.setBrush(QBrush(color))
195
-
196
- # Use the consolidated pen creation method
197
- self.graphics_item.setPen(self._create_pen(color))
198
-
199
- self.graphics_item.setData(0, self.id)
200
- self.graphics_item_group.addToGroup(self.graphics_item)
192
+ # The subclass has already created self.graphics_item.
193
+ # This parent method is now only responsible for styling and grouping it.
194
+ if self.graphics_item:
195
+ color = QColor(self.label.color)
196
+ color.setAlpha(self.transparency)
197
+ self.graphics_item.setBrush(QBrush(color))
198
+
199
+ self.graphics_item.setPen(self._create_pen(color))
200
+
201
+ self.graphics_item.setData(0, self.id)
202
+ self.graphics_item_group.addToGroup(self.graphics_item)
201
203
 
202
- # Create the center graphics item
204
+ # Create and group the helper graphics (center, bbox, etc.)
203
205
  self.create_center_graphics_item(self.center_xy, scene, add_to_group=True)
204
- # Create the bounding box graphics item
205
206
  self.create_bounding_box_graphics_item(self.get_bounding_box_top_left(),
206
207
  self.get_bounding_box_bottom_right(),
207
208
  scene, add_to_group=True)
208
- # Create the polygon graphics item
209
- points = [polygon.at(i) for i in range(polygon.count())]
210
- self.create_polygon_graphics_item(points, scene, add_to_group=True)
211
-
212
- # Add the group to the scene
209
+
210
+ # Add the final group to the scene
213
211
  scene.addItem(self.graphics_item_group)
214
212
 
215
213
  def set_visibility(self, visible):
@@ -289,8 +287,6 @@ class Annotation(QObject):
289
287
  self.center_graphics_item.setPen(pen)
290
288
  if self.bounding_box_graphics_item:
291
289
  self.bounding_box_graphics_item.setPen(pen)
292
- if self.polygon_graphics_item:
293
- self.polygon_graphics_item.setPen(pen)
294
290
 
295
291
  def create_center_graphics_item(self, center_xy, scene, add_to_group=False):
296
292
  """Create a graphical item representing the annotation's center point."""
@@ -345,32 +341,6 @@ class Annotation(QObject):
345
341
  self.graphics_item_group.addToGroup(self.bounding_box_graphics_item)
346
342
  else:
347
343
  scene.addItem(self.bounding_box_graphics_item)
348
-
349
- def create_polygon_graphics_item(self, points, scene, add_to_group=False):
350
- """Create a graphical item representing the annotation's polygon outline."""
351
- try:
352
- has_scene = self.polygon_graphics_item and self.polygon_graphics_item.scene()
353
- except RuntimeError:
354
- self.polygon_graphics_item = None
355
- has_scene = False
356
-
357
- if has_scene:
358
- self.polygon_graphics_item.scene().removeItem(self.polygon_graphics_item)
359
-
360
- color = QColor(self.label.color)
361
- color.setAlpha(self.transparency)
362
-
363
- polygon = QPolygonF(points)
364
- self.polygon_graphics_item = QGraphicsPolygonItem(polygon)
365
- self.polygon_graphics_item.setBrush(color)
366
-
367
- # Use the consolidated pen creation method
368
- self.polygon_graphics_item.setPen(self._create_pen(color))
369
-
370
- if add_to_group and self.graphics_item_group:
371
- self.graphics_item_group.addToGroup(self.polygon_graphics_item)
372
- else:
373
- scene.addItem(self.polygon_graphics_item)
374
344
 
375
345
  def get_center_xy(self):
376
346
  """Get the center coordinates of the annotation."""
@@ -405,9 +375,11 @@ class Annotation(QObject):
405
375
  try:
406
376
  if self.graphics_item_group and self.graphics_item_group.scene():
407
377
  scene = self.graphics_item_group.scene()
378
+
408
379
  except RuntimeError:
409
380
  self.graphics_item_group = None
410
381
  scene = None
382
+
411
383
  if scene is not None:
412
384
  # Remove the old group from the scene
413
385
  scene.removeItem(self.graphics_item_group)
@@ -440,19 +412,6 @@ class Annotation(QObject):
440
412
  # Use the consolidated pen creation method
441
413
  self.bounding_box_graphics_item.setPen(self._create_pen(color))
442
414
 
443
- def update_polygon_graphics_item(self, points):
444
- """Update the shape and appearance of the polygon graphics item."""
445
- if self.polygon_graphics_item:
446
- color = QColor(self.label.color)
447
- color.setAlpha(self.transparency)
448
-
449
- polygon = QPolygonF(points)
450
- self.polygon_graphics_item.setPolygon(polygon)
451
- self.polygon_graphics_item.setBrush(color)
452
-
453
- # Use the consolidated pen creation method
454
- self.polygon_graphics_item.setPen(self._create_pen(color))
455
-
456
415
  def update_transparency(self, transparency: int):
457
416
  """Update the transparency value of the annotation's graphical representation."""
458
417
  if self.transparency != transparency:
@@ -0,0 +1,408 @@
1
+ import warnings
2
+
3
+ import zlib
4
+ import base64
5
+ import rasterio
6
+
7
+ import numpy as np
8
+
9
+ from scipy.ndimage import label as ndimage_label
10
+ from skimage.measure import find_contours
11
+
12
from PyQt5.QtCore import Qt, QPointF, QRectF
from PyQt5.QtWidgets import QGraphicsScene, QGraphicsPixmapItem, QGraphicsItemGroup
from PyQt5.QtGui import QPixmap, QColor, QImage, QPainter, QBrush, QPolygonF
15
+
16
+ from coralnet_toolbox.Annotations.QtAnnotation import Annotation
17
+ from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
18
+
19
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
20
+
21
+
22
+ # ----------------------------------------------------------------------------------------------------------------------
23
+ # Helper Functions for Serialization
24
+ # ----------------------------------------------------------------------------------------------------------------------
25
+
26
+
27
def rle_encode(mask):
    """
    Encode a mask as a compressed, JSON-safe run-length string.

    The array is flattened in C (row-major) order and described as alternating
    ``value,run_length`` pairs, e.g. ``[[0, 0, 1]]`` -> ``"0,2,1,1"``. That text is
    zlib-compressed and base64-encoded so it can be stored in a JSON project file.

    Args:
        mask (np.ndarray): Array of integer (or boolean) class IDs, any shape.

    Returns:
        str: ASCII base64 string; decode it with ``rle_decode`` and the original shape.
    """
    pixels = np.asarray(mask).ravel()

    if pixels.size == 0:
        rle_pairs = ""
    else:
        # Index where each run starts: position 0 plus every position whose value differs
        # from its predecessor. This avoids appending a sentinel value, which forces dtype
        # promotion (uint64 -> float64 yields "3.0" tokens that rle_decode cannot parse)
        # and collides with real data equal to the sentinel (dropping the final run).
        run_starts = np.concatenate(([0], np.flatnonzero(pixels[1:] != pixels[:-1]) + 1))
        run_lengths = np.diff(np.append(run_starts, pixels.size))
        run_values = pixels[run_starts]
        # int() normalises numpy scalars (including bool) to plain integer text.
        rle_pairs = ",".join(f"{int(v)},{int(r)}" for v, r in zip(run_values, run_lengths))

    # Further compress with zlib and encode in base64 for JSON compatibility
    compressed = zlib.compress(rle_pairs.encode('utf-8'))
    return base64.b64encode(compressed).decode('ascii')
44
+
45
+
46
def rle_decode(rle_string, shape, dtype=None):
    """
    Decode a string produced by ``rle_encode`` back into an array.

    Args:
        rle_string (str): Base64 text holding zlib-compressed ``value,run`` pairs.
        shape (tuple | list): Target shape; its product must equal the sum of the runs.
        dtype (optional): If given, the result is cast to this dtype (e.g. the dtype the
            mask had before encoding). By default NumPy's default integer type is used.

    Returns:
        np.ndarray: The decoded array reshaped to ``shape``.

    Raises:
        ValueError: If the run lengths do not add up to the number of elements in ``shape``.
    """
    # Decode from base64 and decompress with zlib
    decompressed = zlib.decompress(base64.b64decode(rle_string)).decode('utf-8')

    # An empty payload encodes an empty array; "".split(',') would yield [''] and int('') fails.
    tokens = decompressed.split(',') if decompressed else []
    values = np.array([int(tok) for tok in tokens[0::2]], dtype=int)
    runs = np.array([int(tok) for tok in tokens[1::2]], dtype=int)

    # Reconstruct the flat pixel array, then restore dtype and 2D layout
    pixels = np.repeat(values, runs)
    if dtype is not None:
        pixels = pixels.astype(dtype)
    return pixels.reshape(shape)
64
+
65
+
66
+ # ----------------------------------------------------------------------------------------------------------------------
67
+ # Classes
68
+ # ----------------------------------------------------------------------------------------------------------------------
69
+
70
+
71
class MaskAnnotation(Annotation):
    """
    Full-image semantic segmentation annotation backed by a 2D array of class IDs.

    Pixel value 0 is treated as background by ``contains_point``, ``get_area`` and
    ``get_class_statistics``; other values are looked up in ``label_map`` to find their
    Label (and display colour).
    """

    def __init__(self,
                 image_path: str,
                 mask_data: np.ndarray,
                 label_map: dict,
                 transparency: int = 128,
                 show_msg: bool = False,
                 rasterio_src=None):
        """
        Initialize a full-image semantic segmentation annotation.
        There should only be one MaskAnnotation per image.

        Args:
            image_path (str): Path to the source image.
            mask_data (np.ndarray): 2D numpy array of integer class IDs, matching image dimensions.
            label_map (dict): A map of {class_id: Label_object}.
            transparency (int): The alpha value (0-255) for displaying the mask overlay.
            show_msg (bool): Forwarded to the base Annotation constructor.
            rasterio_src: Optional rasterio dataset object for the source image.

        Raises:
            ValueError: If ``label_map`` is empty.
        """
        # For a full-image mask, the concept of a single "primary label" is ambiguous.
        # We'll use the first available label as a placeholder to satisfy the base class.
        if not label_map:
            raise ValueError("label_map cannot be empty.")
        placeholder_label = next(iter(label_map.values()))

        super().__init__(
            short_label_code=placeholder_label.short_label_code,
            long_label_code=placeholder_label.long_label_code,
            color=placeholder_label.color,
            image_path=image_path,
            label_id=placeholder_label.id,
            transparency=transparency,
            show_msg=show_msg
        )

        self.mask_data = mask_data
        self.label_map = label_map
        self.offset = QPointF(0, 0)  # A full-image mask has no offset
        self.rasterio_src = rasterio_src  # Store rasterio source if provided

        # Set geometric properties from mask
        self.set_centroid()
        self.set_cropped_bbox()

    def set_centroid(self):
        """Set the centroid to the center of the image."""
        height, width = self.mask_data.shape
        self.center_xy = QPointF(width / 2.0, height / 2.0)

    def set_cropped_bbox(self):
        """Set the bounding box to the full dimensions of the image."""
        height, width = self.mask_data.shape
        self.cropped_bbox = (0, 0, width, height)
        self.annotation_size = int(max(width, height))

    def _render_mask_to_pixmap(self) -> QPixmap:
        """
        Convert ``mask_data`` into a coloured, semi-transparent QPixmap.

        Each class ID present in ``label_map`` is drawn in its label colour with alpha
        ``self.transparency``. IDs without an entry (background 0 unless mapped, or IDs
        whose labels could not be resolved in ``from_dict``) are drawn fully transparent
        instead of raising an IndexError during the colour lookup.
        """
        height, width = self.mask_data.shape
        ids = np.asarray(self.mask_data)

        # Lookup table from class ID to RGBA; negative keys cannot index an array, so skip them.
        max_id = max((cid for cid in self.label_map if cid >= 0), default=-1)
        color_map = np.zeros((max_id + 1, 4), dtype=np.uint8)
        for class_id, label in self.label_map.items():
            if class_id >= 0:
                color = label.color
                color_map[class_id] = [color.red(), color.green(), color.blue(), self.transparency]

        colored_mask = np.zeros((height, width, 4), dtype=np.uint8)
        known = (ids >= 0) & (ids <= max_id)
        # astype(intp) makes bool/unsigned masks index the table by value, not as a boolean mask.
        colored_mask[known] = color_map[ids[known].astype(np.intp)]

        # bytesPerLine is passed explicitly so QImage never assumes a different row stride.
        # QPixmap.fromImage copies the pixels, so the numpy buffer may be freed afterwards.
        q_image = QImage(colored_mask.data, width, height, 4 * width, QImage.Format_RGBA8888)
        return QPixmap.fromImage(q_image)

    def contains_point(self, point: QPointF) -> bool:
        """Check if a point is within the mask's classified (non-zero) area."""
        x, y = int(point.x()), int(point.y())
        height, width = self.mask_data.shape
        if 0 <= y < height and 0 <= x < width:
            return bool(self.mask_data[y, x] > 0)
        return False

    def get_area(self):
        """Return the total number of non-background (non-zero) pixels."""
        return np.count_nonzero(self.mask_data)

    def get_bounding_box_top_left(self):
        """Get the top-left corner of the annotation's bounding box (always 0,0)."""
        return QPointF(0, 0)

    def get_bounding_box_bottom_right(self):
        """Get the bottom-right corner of the annotation's bounding box (image width, height)."""
        height, width = self.mask_data.shape
        return QPointF(width, height)

    def create_graphics_item(self, scene: QGraphicsScene):
        """
        Build the pixmap overlay plus centre/bounding-box helpers and add them to ``scene``.

        The base implementation styles ``graphics_item`` with ``setBrush``/``setPen``, which
        QGraphicsPixmapItem does not provide (AttributeError), so the grouping is done here
        without that styling; transparency is already baked into the pixmap's alpha channel.
        NOTE(review): other base-class methods that style ``graphics_item`` may need similar
        overrides — verify against the base class.
        """
        # Remove the previous group, if any, before building a new one
        if self.graphics_item_group and self.graphics_item_group.scene():
            self.graphics_item_group.scene().removeItem(self.graphics_item_group)
        self.center_graphics_item = None
        self.bounding_box_graphics_item = None
        self.graphics_item_group = QGraphicsItemGroup()

        self.graphics_item = QGraphicsPixmapItem(self._render_mask_to_pixmap())
        self.graphics_item.setPos(self.offset)
        self.graphics_item.setData(0, self.id)
        self.graphics_item_group.addToGroup(self.graphics_item)

        # Create and group the helper graphics (center, bbox)
        self.create_center_graphics_item(self.center_xy, scene, add_to_group=True)
        self.create_bounding_box_graphics_item(self.get_bounding_box_top_left(),
                                               self.get_bounding_box_bottom_right(),
                                               scene, add_to_group=True)

        # Add the final group to the scene
        scene.addItem(self.graphics_item_group)

    def update_graphics_item(self):
        """Re-render the pixmap from the current mask data, then defer to the base update."""
        if self.graphics_item:
            pixmap = self._render_mask_to_pixmap()
            self.graphics_item.setPixmap(pixmap)
        super().update_graphics_item()

    def update_mask(self, brush_location: QPointF, brush_mask: np.ndarray, new_class_id: int):
        """
        Paint ``new_class_id`` wherever ``brush_mask`` is truthy.

        ``brush_mask``'s top-left corner is placed at ``brush_location``; the parts that fall
        outside the image are clipped away.
        """
        x_start, y_start = int(brush_location.x()), int(brush_location.y())
        brush_h, brush_w = brush_mask.shape
        mask_h, mask_w = self.mask_data.shape

        x_end = min(x_start + brush_w, mask_w)
        y_end = min(y_start + brush_h, mask_h)
        clipped_x_start = max(x_start, 0)
        clipped_y_start = max(y_start, 0)

        # Brush lies entirely outside the image
        if clipped_x_start >= x_end or clipped_y_start >= y_end:
            return

        target_slice = self.mask_data[clipped_y_start:y_end, clipped_x_start:x_end]
        brush_x_offset = clipped_x_start - x_start
        brush_y_offset = clipped_y_start - y_start
        clipped_brush_mask = brush_mask[brush_y_offset:brush_y_offset + target_slice.shape[0],
                                        brush_x_offset:brush_x_offset + target_slice.shape[1]]

        # Cast to bool: an integer 0/1 brush would otherwise be used as fancy *indices*.
        target_slice[clipped_brush_mask.astype(bool, copy=False)] = new_class_id

        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    # --- Data Manipulation & Editing Methods ---

    def fill_region(self, point: QPointF, new_class_id: int):
        """Fill the contiguous region under ``point`` with ``new_class_id`` (paint bucket tool)."""
        x, y = int(point.x()), int(point.y())
        height, width = self.mask_data.shape
        if not (0 <= y < height and 0 <= x < width):
            return

        old_class_id = self.mask_data[y, x]
        if old_class_id == new_class_id:
            return

        # Label connected components of the clicked class (scipy's default connectivity)
        labeled_regions, _ = ndimage_label(self.mask_data == old_class_id)
        self.mask_data[labeled_regions == labeled_regions[y, x]] = new_class_id

        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    def replace_class(self, old_class_id: int, new_class_id: int):
        """Replace all pixels of one class ID with another across the entire mask."""
        self.mask_data[self.mask_data == old_class_id] = new_class_id
        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    def import_annotations(self, annotations: list, class_id: int):
        """
        Burn vector annotations into the mask with ``class_id``.

        Each annotation's ``get_painter_path()`` is filled (no outline) into an 8-bit alpha
        stencil the size of the mask; every painted pixel is set to ``class_id``.
        """
        height, width = self.mask_data.shape
        # Use Format_Alpha8 for an efficient 8-bit stencil mask
        stencil = QImage(width, height, QImage.Format_Alpha8)
        stencil.fill(Qt.transparent)

        painter = QPainter(stencil)
        try:
            painter.setPen(Qt.NoPen)
            painter.setBrush(QBrush(QColor("white")))  # Opaque white
            for anno in annotations:
                painter.drawPath(anno.get_painter_path())
        finally:
            # Always release the painter, even if an annotation fails to produce a path
            painter.end()

        # QImage scanlines are 32-bit aligned, so each row occupies bytesPerLine() bytes,
        # which exceeds `width` whenever width is not a multiple of 4. Read whole rows and
        # crop the padding instead of reshaping the buffer straight to (height, width).
        bytes_per_line = stencil.bytesPerLine()
        ptr = stencil.bits()
        ptr.setsize(bytes_per_line * height)
        rows = np.frombuffer(ptr, dtype=np.uint8).reshape(height, bytes_per_line)
        stencil_np = rows[:, :width] > 0  # True where painted (new array, independent of ptr)

        # Update the mask data where the stencil is True
        self.mask_data[stencil_np] = class_id

        self.update_graphics_item()
        self.annotationUpdated.emit(self)

    # --- Analysis & Information Retrieval Methods ---

    def get_class_statistics(self) -> dict:
        """
        Return pixel counts and percentages per class, keyed by short label code.

        Percentages are relative to the non-background pixel count; classes without an entry
        in ``label_map`` are omitted.
        """
        stats = {}
        total_pixels = self.get_area()
        if total_pixels == 0:
            return stats

        class_ids, counts = np.unique(self.mask_data[self.mask_data > 0], return_counts=True)

        for cid, count in zip(class_ids, counts):
            label = self.label_map.get(int(cid))
            if label:
                stats[label.short_label_code] = {
                    "pixel_count": int(count),
                    "percentage": (count / total_pixels) * 100
                }
        return stats

    def get_class_at_point(self, point: QPointF) -> int:
        """Return the class ID at ``point`` as a Python int (0 if outside the image)."""
        x, y = int(point.x()), int(point.y())
        height, width = self.mask_data.shape
        if 0 <= y < height and 0 <= x < width:
            return int(self.mask_data[y, x])
        return 0  # Return background class if outside bounds

    # --- Conversion & Exporting Methods ---

    def get_binary_mask(self, class_id: int) -> np.ndarray:
        """Return a boolean numpy array where True corresponds to the given class ID."""
        return self.mask_data == class_id

    def to_instance_polygons(self, class_id: int) -> list:
        """
        Convert the contours of ``class_id`` into PolygonAnnotations.

        find_contours traces every iso-contour, so the boundary of a hole inside a region is
        returned as its own polygon too (holes are not subtracted).

        Raises:
            KeyError: If a polygon is produced but ``class_id`` is not in ``label_map``.
        """
        binary_mask = self.get_binary_mask(class_id)
        # Add padding so contours touching the image border are closed
        padded_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)

        # Level 0.5 lies between the 0 and 1 values; float input is what find_contours expects
        contours = find_contours(padded_mask.astype(np.float64), level=0.5)

        # Remove padding offset and swap (row, col) to (x, y)
        point_lists = [[QPointF(p[1] - 1, p[0] - 1) for p in contour] for contour in contours]
        # Must have at least 3 points for a valid polygon
        point_lists = [points for points in point_lists if len(points) > 2]
        if not point_lists:
            return []

        # Use the label associated with the class_id for the new annotations
        label = self.label_map[class_id]
        return [
            PolygonAnnotation(
                points=points,
                short_label_code=label.short_label_code,
                long_label_code=label.long_label_code,
                color=label.color,
                image_path=self.image_path,
                label_id=label.id
            )
            for points in point_lists
        ]

    def export_as_png(self, path: str, use_label_colors: bool = True):
        """
        Save the mask to a PNG file.

        With ``use_label_colors`` the coloured overlay is written; otherwise raw class IDs are
        written as 8-bit grayscale, which requires every ID to be in 0..255 (a warning is
        issued and nothing is written otherwise).
        """
        if use_label_colors:
            # Use rendering logic to create a colored image
            pixmap = self._render_mask_to_pixmap()
            pixmap.toImage().save(path)
            return

        height, width = self.mask_data.shape
        if self.mask_data.size and self.mask_data.max() > 255:
            warnings.warn("Mask contains class IDs > 255; cannot save as 8-bit grayscale PNG.")
            return
        if self.mask_data.size and self.mask_data.min() < 0:
            warnings.warn("Mask contains negative class IDs; cannot save as 8-bit grayscale PNG.")
            return

        # C-contiguous uint8 rows of exactly `width` bytes; bytesPerLine must be passed because
        # QImage otherwise assumes 32-bit aligned rows and skews images whose width % 4 != 0.
        img_data = np.ascontiguousarray(self.mask_data, dtype=np.uint8)
        q_image = QImage(img_data.data, width, height, width, QImage.Format_Grayscale8)
        q_image.save(path)

    def export_as_raster(self, path: str):
        """Save the mask data to a single-band GeoTIFF using rasterio."""
        profile = {
            'driver': 'GTiff',
            'height': self.mask_data.shape[0],
            'width': self.mask_data.shape[1],
            'count': 1,
            'dtype': self.mask_data.dtype
        }

        # If the original image was opened with rasterio, copy its spatial metadata
        if self.rasterio_src:
            profile['crs'] = self.rasterio_src.crs
            profile['transform'] = self.rasterio_src.transform

        with rasterio.open(path, 'w', **profile) as dst:
            dst.write(self.mask_data, 1)

    # --- Serialization & Deserialization ---

    def to_dict(self):
        """
        Serialize the annotation to a dictionary, with RLE for the mask.

        Adds ``shape``, ``rle_mask``, ``mask_dtype`` and ``label_map`` ({class_id: short code})
        to the base dictionary.
        """
        base_dict = super().to_dict()
        rle_string = rle_encode(self.mask_data)
        base_dict.update({
            'shape': self.mask_data.shape,
            'rle_mask': rle_string,
            # Lets from_dict restore the original dtype (RLE decoding yields default ints)
            'mask_dtype': str(self.mask_data.dtype),
            # int() so numpy integer keys do not break json serialization
            'label_map': {int(cid): label.short_label_code for cid, label in self.label_map.items()}
        })
        return base_dict

    @classmethod
    def from_dict(cls, data, label_window):
        """
        Instantiate a MaskAnnotation from a dictionary produced by ``to_dict``.

        Labels whose short code is unknown to ``label_window`` are dropped; their pixels are
        then rendered transparent.
        """
        label_map = {}
        for cid_str, short_code in data['label_map'].items():
            label = label_window.get_label_by_short_code(short_code)
            if label:
                label_map[int(cid_str)] = label

        mask_data = rle_decode(data['rle_mask'], data['shape'])
        # Older saves have no dtype entry; keep the decoder's default integer type for them
        if 'mask_dtype' in data:
            mask_data = mask_data.astype(data['mask_dtype'])

        annotation = cls(
            image_path=data['image_path'],
            mask_data=mask_data,
            label_map=label_map
        )
        annotation.id = data.get('id', annotation.id)
        annotation.data = data.get('data', {})
        return annotation

    @classmethod
    def from_rasterio(cls, file_path: str, image_path: str, label_map: dict):
        """Create a MaskAnnotation from band 1 of a raster file."""
        with rasterio.open(file_path) as src:
            mask_data = src.read(1)
            # NOTE(review): `src` is closed once this with-block exits, yet export_as_raster later
            # reads .crs/.transform from it — confirm rasterio keeps these cached after close.
            return cls(
                image_path=image_path,
                mask_data=mask_data,
                label_map=label_map,
                rasterio_src=src
            )

    # --- Compatibility Methods ---
    def get_perimeter(self):
        """Return the perimeter of the full-image rectangle."""
        height, width = self.mask_data.shape
        return 2 * (width + height)

    def get_polygon(self):
        """Return the full-image rectangle as a QPolygonF."""
        height, width = self.mask_data.shape
        return QPolygonF(QRectF(0, 0, width, height))

    def __repr__(self):
        return (f"MaskAnnotation(id={self.id}, image_path={self.image_path}, "
                f"shape={self.mask_data.shape})")