coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
  8. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  9. coralnet_toolbox/Explorer/QtDataItem.py +1 -1
  10. coralnet_toolbox/Explorer/QtExplorer.py +159 -17
  11. coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
  12. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  13. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  14. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  15. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  16. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
  17. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
  18. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
  19. coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
  20. coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
  21. coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
  22. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
  23. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  24. coralnet_toolbox/QtAnnotationWindow.py +42 -14
  25. coralnet_toolbox/QtEventFilter.py +19 -2
  26. coralnet_toolbox/QtImageWindow.py +134 -86
  27. coralnet_toolbox/QtLabelWindow.py +14 -2
  28. coralnet_toolbox/QtMainWindow.py +122 -9
  29. coralnet_toolbox/QtProgressBar.py +52 -27
  30. coralnet_toolbox/Rasters/QtRaster.py +59 -7
  31. coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
  32. coralnet_toolbox/SAM/QtBatchInference.py +0 -2
  33. coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
  34. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  35. coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
  36. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
  37. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
  38. coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
  39. coralnet_toolbox/SeeAnything/__init__.py +2 -0
  40. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  41. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  42. coralnet_toolbox/Tools/QtSAMTool.py +222 -57
  43. coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
  44. coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
  45. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  46. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  47. coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
  48. coralnet_toolbox/Tools/__init__.py +2 -0
  49. coralnet_toolbox/__init__.py +1 -1
  50. coralnet_toolbox/utilities.py +137 -47
  51. coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
  52. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
  53. coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
  54. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
  55. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
  56. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
  57. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -197,30 +197,50 @@ class ExportTagLabAnnotations:
197
197
  """
198
198
  # Calculate bounding box, centroid, area, perimeter, and contour
199
199
  points = annotation.points
200
+ # Convert points to TagLab contour format
201
+ contour = self.taglabToPoints(np.array([[point.x(), point.y()] for point in points]))
202
+ inner_contours = []
203
+ if hasattr(annotation, 'holes') and annotation.holes:
204
+ # Convert holes to TagLab format
205
+ for hole in annotation.holes:
206
+ inner_contours.append(self.taglabToPoints(np.array([[point.x(), point.y()] for point in hole])))
207
+
208
+ # Calculate bounding box
200
209
  min_x = int(min(point.x() for point in points))
201
210
  min_y = int(min(point.y() for point in points))
202
211
  max_x = int(max(point.x() for point in points))
203
- max_y = int(max(point.y() for point in points))
212
+ max_y = int(max(point.y() for point in points))
204
213
  width = max_x - min_x
205
214
  height = max_y - min_y
215
+ bbox = [min_x, min_y, width, height]
216
+
217
+ # Calculate centroid
206
218
  centroid_x = float(f"{sum(point.x() for point in points) / len(points):.1f}")
207
219
  centroid_y = float(f"{sum(point.y() for point in points) / len(points):.1f}")
220
+ centroid = [centroid_x, centroid_y]
221
+
208
222
  area = float(f"{annotation.get_area():.1f}")
209
223
  perimeter = float(f"{annotation.get_perimeter():.1f}")
210
- contour = self.taglabToPoints(np.array([[point.x(), point.y()] for point in points]))
224
+ data = annotation.data if hasattr(annotation, 'data') else {}
225
+
226
+ # Pop these keys from data if they exist
227
+ class_name = data.pop('class_name', annotation.label.short_label_code)
228
+ instance_name = data.pop('instance_name', "coral0")
229
+ blob_name = data.pop('blob_name', f"c-0-{centroid_x}x-{centroid_y}y")
230
+ note = data.pop('note', "")
211
231
 
212
232
  annotation_dict = {
213
- "bbox": [min_y, min_x, width, height],
214
- "centroid": [centroid_x, centroid_y],
233
+ "bbox": bbox,
234
+ "centroid": centroid,
215
235
  "area": area,
216
236
  "perimeter": perimeter,
217
237
  "contour": contour,
218
- "inner contours": [],
219
- "class name": annotation.label.short_label_code,
220
- "instance name": "coral0", # Placeholder, update as needed
221
- "blob name": f"c-0-{centroid_x}x-{centroid_y}y",
222
- "note": "",
223
- "data": {}
238
+ "inner contours": inner_contours,
239
+ "class name": class_name,
240
+ "instance name": instance_name,
241
+ "blob name": blob_name,
242
+ "note": note,
243
+ "data": data
224
244
  }
225
245
 
226
246
  return annotation_dict
@@ -161,23 +161,25 @@ class ImportTagLabAnnotations:
161
161
  short_label_code = label_info['name'].strip()
162
162
  long_label_code = label_info['name'].strip()
163
163
  color = QColor(*label_info['fill'])
164
-
165
- # Unpack the annotation data
166
- bbox = annotation['bbox']
167
- centroid = annotation['centroid']
168
- area = annotation['area']
169
- perimeter = annotation['perimeter']
170
- contour = annotation['contour']
171
- inner_contours = annotation['inner contours']
172
- class_name = annotation['class name']
173
- instance_name = annotation['instance name']
174
- blob_name = annotation['blob name']
175
- idx = annotation['id']
176
- note = annotation['note']
177
- data = annotation['data']
164
+
165
+ # Pack all other data into a dict
166
+ imported_data = {
167
+ 'bbox': annotation.get('bbox'),
168
+ 'centroid': annotation.get('centroid'),
169
+ 'area': annotation.get('area'),
170
+ 'perimeter': annotation.get('perimeter'),
171
+ 'class_name': annotation.get('class name'),
172
+ 'instance_name': annotation.get('instance name'),
173
+ 'blob_name': annotation.get('blob name'),
174
+ 'id': annotation.get('id'),
175
+ 'note': annotation.get('note'),
176
+ 'data': annotation.get('data'),
177
+ }
178
178
 
179
179
  # Convert contour string to points
180
180
  points = self.parse_contour(annotation['contour'])
181
+ # Convert inner contours to a list of lists of points (holes)
182
+ holes = [self.parse_contour(inner) for inner in annotation.get('inner contours', [])]
181
183
 
182
184
  # Create the label if it doesn't exist
183
185
  label = self.label_window.add_label_if_not_exists(short_label_code,
@@ -191,8 +193,12 @@ class ImportTagLabAnnotations:
191
193
  long_label_code=long_label_code,
192
194
  color=color,
193
195
  image_path=image_full_path,
194
- label_id=label_id
196
+ label_id=label_id,
197
+ holes=holes,
195
198
  )
199
+ # Add additional data to the annotation
200
+ polygon_annotation.data = imported_data
201
+
196
202
  # Add annotation to the dict
197
203
  self.annotation_window.add_annotation_to_dict(polygon_annotation)
198
204
 
@@ -12,16 +12,12 @@ from PyQt5.QtWidgets import (QDialog, QFileDialog, QVBoxLayout, QPushButton, QLa
12
12
  QLineEdit)
13
13
 
14
14
  from coralnet_toolbox.QtLabelWindow import Label
15
-
16
15
  from coralnet_toolbox.QtWorkArea import WorkArea
17
-
18
16
  from coralnet_toolbox.Annotations.QtPatchAnnotation import PatchAnnotation
19
17
  from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
20
18
  from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
21
19
  from coralnet_toolbox.Annotations.QtMultiPolygonAnnotation import MultiPolygonAnnotation
22
-
23
20
  from coralnet_toolbox.Common.QtUpdateImagePaths import UpdateImagePaths
24
-
25
21
  from coralnet_toolbox.QtProgressBar import ProgressBar
26
22
 
27
23
 
@@ -143,19 +139,18 @@ class OpenProject(QDialog):
143
139
  with open(file_path, 'r') as file:
144
140
  project_data = json.load(file)
145
141
 
142
+ # Handle both new and old project formats for images and work areas
143
+ images_data = project_data.get('images', project_data.get('image_paths'))
144
+ legacy_workareas = project_data.get('workareas') # For backward compatibility
145
+
146
146
  # Update main window with loaded project data
147
- self.import_images(project_data.get('image_paths'))
148
- self.import_workareas(project_data.get('workareas'))
147
+ self.import_images(images_data, legacy_workareas)
149
148
  self.import_labels(project_data.get('labels'))
150
149
  self.import_annotations(project_data.get('annotations'))
151
150
 
152
151
  # Update current project path
153
152
  self.current_project_path = file_path
154
153
 
155
- QMessageBox.information(self.annotation_window,
156
- "Project Loaded",
157
- "Project has been successfully loaded.")
158
-
159
154
  except Exception as e:
160
155
  QMessageBox.warning(self.annotation_window,
161
156
  "Error Loading Project",
@@ -168,10 +163,15 @@ class OpenProject(QDialog):
168
163
  # Exit
169
164
  self.accept()
170
165
 
171
- def import_images(self, image_paths):
172
- """Import images from the given paths."""
173
- if not image_paths:
166
+ def import_images(self, images_data, legacy_workareas=None):
167
+ """Import images, states, and work areas from the given data."""
168
+ if not images_data:
174
169
  return
170
+
171
+ # Determine if the format is old (list of strings) or new (list of dicts)
172
+ is_new_format = isinstance(images_data[0], dict)
173
+
174
+ image_paths = [img['path'] for img in images_data] if is_new_format else images_data
175
175
 
176
176
  if not all([os.path.exists(path) for path in image_paths]):
177
177
  image_paths, self.updated_paths = UpdateImagePaths.update_paths(image_paths)
@@ -183,15 +183,46 @@ class OpenProject(QDialog):
183
183
  progress_bar.start_progress(total_images)
184
184
 
185
185
  try:
186
+ # Create a map for quick data lookup if using the new format
187
+ image_data_map = {img['path']: img for img in images_data} if is_new_format else {}
188
+
186
189
  # Add images to the image window's raster manager one by one
187
190
  for path in image_paths:
188
- # Use the improved add_image method which handles both
189
- # adding to raster_manager and updating filtered_paths
190
191
  self.image_window.add_image(path)
192
+ raster = self.image_window.raster_manager.get_raster(path)
193
+ if not raster:
194
+ continue
195
+
196
+ # If using the new format, apply saved state and work areas
197
+ if is_new_format and path in image_data_map:
198
+ data = image_data_map[path]
199
+ state = data.get('state', {})
200
+ work_areas_list = data.get('work_areas', [])
201
+
202
+ # Apply raster state
203
+ raster.checkbox_state = state.get('checkbox_state', False)
204
+
205
+ # Import work areas for this image
206
+ for work_area_data in work_areas_list:
207
+ try:
208
+ work_area = WorkArea.from_dict(work_area_data, path)
209
+ raster.add_work_area(work_area)
210
+ except Exception as e:
211
+ print(f"Warning: Could not import work area {work_area_data}: {str(e)}")
191
212
 
192
213
  # Update the progress bar
193
214
  progress_bar.update_progress()
194
215
 
216
+ # Handle backward compatibility for old, top-level work areas
217
+ if legacy_workareas:
218
+ for image_path, work_areas_list in legacy_workareas.items():
219
+ current_path = self.updated_paths.get(image_path, image_path)
220
+ raster = self.image_window.raster_manager.get_raster(current_path)
221
+ if raster:
222
+ for work_area_data in work_areas_list:
223
+ work_area = WorkArea.from_dict(work_area_data, current_path)
224
+ raster.add_work_area(work_area)
225
+
195
226
  # Show the last image if any were imported
196
227
  if self.image_window.raster_manager.image_paths:
197
228
  self.image_window.load_image_by_path(self.image_window.raster_manager.image_paths[-1])
@@ -204,69 +235,6 @@ class OpenProject(QDialog):
204
235
  # Close progress bar
205
236
  progress_bar.stop_progress()
206
237
  progress_bar.close()
207
-
208
- def import_workareas(self, workareas):
209
- """Import work areas for each image."""
210
- if not workareas:
211
- return
212
-
213
- # Start the progress bar
214
- total_images = len(workareas)
215
- progress_bar = ProgressBar(self.annotation_window, title="Importing Work Areas")
216
- progress_bar.show()
217
- progress_bar.start_progress(total_images)
218
-
219
- try:
220
- # Loop through each image's work areas
221
- for image_path, work_areas_list in workareas.items():
222
-
223
- # Check if the image path was updated (moved)
224
- updated_path = False
225
-
226
- if image_path not in self.image_window.raster_manager.image_paths:
227
- # Check if the path was updated
228
- if image_path in self.updated_paths:
229
- image_path = self.updated_paths[image_path]
230
- updated_path = True
231
- else:
232
- print(f"Warning: Image not found for work areas: {image_path}")
233
- continue
234
-
235
- # Get the raster for this image
236
- raster = self.image_window.raster_manager.get_raster(image_path)
237
- if not raster:
238
- print(f"Warning: Could not get raster for image: {image_path}")
239
- continue
240
-
241
- # Import each work area for this image
242
- for work_area_data in work_areas_list:
243
- try:
244
- # Update image path if it was changed
245
- if updated_path:
246
- work_area_data['image_path'] = image_path
247
-
248
- # Create WorkArea from dictionary
249
- work_area = WorkArea.from_dict(work_area_data, image_path)
250
-
251
- # Add work area to the raster
252
- raster.add_work_area(work_area)
253
-
254
- except Exception as e:
255
- print(f"Warning: Could not import work area {work_area_data}: {str(e)}")
256
- continue
257
-
258
- # Update the progress bar
259
- progress_bar.update_progress()
260
-
261
- except Exception as e:
262
- QMessageBox.warning(self.annotation_window,
263
- "Error Importing Work Areas",
264
- f"An error occurred while importing work areas: {str(e)}")
265
-
266
- finally:
267
- # Close progress bar
268
- progress_bar.stop_progress()
269
- progress_bar.close()
270
238
 
271
239
  def import_labels(self, labels):
272
240
  """Import labels from the given list."""
@@ -96,10 +96,9 @@ class SaveProject(QDialog):
96
96
 
97
97
  try:
98
98
  project_data = {
99
- 'image_paths': self.get_images(),
99
+ 'images': self.get_images(),
100
100
  'labels': self.get_labels(),
101
- 'annotations': self.get_annotations(),
102
- 'workareas': self.get_workareas()
101
+ 'annotations': self.get_annotations()
103
102
  }
104
103
 
105
104
  with open(file_path, 'w') as file:
@@ -125,10 +124,10 @@ class SaveProject(QDialog):
125
124
  self.accept()
126
125
 
127
126
  def get_images(self):
128
- """Get the list of image paths to export."""
127
+ """Get the list of image objects, including paths, states, and work areas."""
129
128
  # Start the progress bar
130
129
  total_images = len(self.image_window.raster_manager.image_paths)
131
- progress_bar = ProgressBar(self.label_window, "Exporting Images")
130
+ progress_bar = ProgressBar(self.label_window, "Exporting Image Data")
132
131
  progress_bar.show()
133
132
  progress_bar.start_progress(total_images)
134
133
 
@@ -137,7 +136,19 @@ class SaveProject(QDialog):
137
136
 
138
137
  # Loop through all of the image paths
139
138
  for image_path in self.image_window.raster_manager.image_paths:
140
- export_images.append(image_path)
139
+ raster = self.image_window.raster_manager.get_raster(image_path)
140
+ if raster:
141
+ # Get work areas for this raster
142
+ work_areas_list = [wa.to_dict() for wa in raster.get_work_areas()]
143
+
144
+ image_data = {
145
+ 'path': image_path,
146
+ 'state': {
147
+ 'checkbox_state': raster.checkbox_state
148
+ },
149
+ 'work_areas': work_areas_list
150
+ }
151
+ export_images.append(image_data)
141
152
  progress_bar.update_progress()
142
153
 
143
154
  except Exception as e:
@@ -234,42 +245,6 @@ class SaveProject(QDialog):
234
245
 
235
246
  return export_annotations
236
247
 
237
- def get_workareas(self):
238
- """Get the work areas to export."""
239
- # Start progress bar
240
- total_rasters = len(self.image_window.raster_manager.image_paths)
241
- progress_bar = ProgressBar(self.annotation_window, title="Exporting Work Areas")
242
- progress_bar.show()
243
- progress_bar.start_progress(total_rasters)
244
-
245
- try:
246
- export_workareas = {}
247
-
248
- # Loop through all rasters to get their work areas
249
- for image_path in self.image_window.raster_manager.image_paths:
250
- raster = self.image_window.raster_manager.get_raster(image_path)
251
- if raster and raster.has_work_areas():
252
- work_areas_list = []
253
- for work_area in raster.get_work_areas():
254
- work_areas_list.append(work_area.to_dict())
255
-
256
- if work_areas_list: # Only add if there are work areas
257
- export_workareas[image_path] = work_areas_list
258
-
259
- progress_bar.update_progress()
260
-
261
- except Exception as e:
262
- QMessageBox.warning(self.annotation_window,
263
- "Error Exporting Work Areas",
264
- f"An error occurred while exporting work areas: {str(e)}")
265
-
266
- finally:
267
- # Stop the progress bar
268
- progress_bar.stop_progress()
269
- progress_bar.close()
270
-
271
- return export_workareas
272
-
273
248
  def get_project_path(self):
274
249
  """Get the current project path."""
275
250
  return self.current_project_path
@@ -290,4 +265,4 @@ class SaveProject(QDialog):
290
265
  """Handle dialog rejection (Cancel or close)"""
291
266
  if self.current_project_path:
292
267
  self.file_path_edit.setText(self.current_project_path)
293
- super().reject()
268
+ super().reject()
@@ -123,18 +123,29 @@ class Detect(Base):
123
123
 
124
124
  def update_sam_task_state(self):
125
125
  """
126
- Centralized method to check if SAM is loaded and update task and dropdown accordingly.
126
+ Centralized method to check if SAM is loaded and update task accordingly.
127
+ If the user has selected to use SAM, this function ensures the task is set to 'segment'.
128
+ Crucially, it does NOT alter the task if SAM is not selected, respecting the
129
+ user's choice from the 'Task' dropdown.
127
130
  """
128
- sam_active = (
129
- self.sam_dialog is not None and
130
- self.sam_dialog.loaded_model is not None and
131
- self.use_sam_dropdown.currentText() == "True"
132
- )
133
- if sam_active:
134
- self.task = 'segment'
135
- else:
136
- self.task = 'detect'
137
- self.use_sam_dropdown.setCurrentText("False")
131
+ # Check if the user wants to use the SAM model
132
+ if self.use_sam_dropdown.currentText() == "True":
133
+ # SAM is requested. Check if it's actually available.
134
+ sam_is_available = (
135
+ hasattr(self, 'sam_dialog') and
136
+ self.sam_dialog is not None and
137
+ self.sam_dialog.loaded_model is not None
138
+ )
139
+
140
+ if sam_is_available:
141
+ # If SAM is wanted and available, the task must be segmentation.
142
+ self.task = 'segment'
143
+ else:
144
+ # If SAM is wanted but not available, revert the dropdown and do nothing else.
145
+ # The 'is_sam_model_deployed' function already handles showing an error message.
146
+ self.use_sam_dropdown.setCurrentText("False")
147
+
148
+ # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
138
149
 
139
150
  def load_model(self):
140
151
  """
@@ -123,17 +123,29 @@ class Segment(Base):
123
123
 
124
124
  def update_sam_task_state(self):
125
125
  """
126
- Centralized method to check if SAM is loaded and update task and dropdown accordingly.
126
+ Centralized method to check if SAM is loaded and update task accordingly.
127
+ If the user has selected to use SAM, this function ensures the task is set to 'segment'.
128
+ Crucially, it does NOT alter the task if SAM is not selected, respecting the
129
+ user's choice from the 'Task' dropdown.
127
130
  """
128
- sam_active = (
129
- self.sam_dialog is not None and
130
- self.sam_dialog.loaded_model is not None and
131
- self.use_sam_dropdown.currentText() == "True"
132
- )
133
- if sam_active:
134
- self.task = 'segment'
135
- else:
136
- self.use_sam_dropdown.setCurrentText("False")
131
+ # Check if the user wants to use the SAM model
132
+ if self.use_sam_dropdown.currentText() == "True":
133
+ # SAM is requested. Check if it's actually available.
134
+ sam_is_available = (
135
+ hasattr(self, 'sam_dialog') and
136
+ self.sam_dialog is not None and
137
+ self.sam_dialog.loaded_model is not None
138
+ )
139
+
140
+ if sam_is_available:
141
+ # If SAM is wanted and available, the task must be segmentation.
142
+ self.task = 'segment'
143
+ else:
144
+ # If SAM is wanted but not available, revert the dropdown and do nothing else.
145
+ # The 'is_sam_model_deployed' function already handles showing an error message.
146
+ self.use_sam_dropdown.setCurrentText("False")
147
+
148
+ # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
137
149
 
138
150
  def load_model(self):
139
151
  """
@@ -42,7 +42,7 @@ class Base(QDialog):
42
42
  self.annotation_window = main_window.annotation_window
43
43
  self.image_window = main_window.image_window
44
44
 
45
- self.resize(1000, 600)
45
+ self.resize(800, 800)
46
46
  self.setWindowIcon(get_icon("coral.png"))
47
47
  self.setWindowTitle("Export Dataset")
48
48
 
@@ -64,10 +64,8 @@ class Base(QDialog):
64
64
  self.setup_output_layout()
65
65
  # Setup the ratio layout
66
66
  self.setup_ratio_layout()
67
- # Setup the annotation layout
68
- self.setup_annotation_layout()
69
- # Setup the options layout
70
- self.setup_options_layout()
67
+ # Setup the data selection layout
68
+ self.setup_data_selection_layout()
71
69
  # Setup the table layout
72
70
  self.setup_table_layout()
73
71
  # Setup the status layout
@@ -147,10 +145,25 @@ class Base(QDialog):
147
145
  group_box.setLayout(layout)
148
146
  self.layout.addWidget(group_box)
149
147
 
150
- def setup_annotation_layout(self):
151
- """Setup the annotation type checkboxes layout."""
148
+ def setup_data_selection_layout(self):
149
+ """Setup the layout for data selection options in a horizontal arrangement."""
150
+ options_layout = QHBoxLayout()
151
+
152
+ # Create and add the group boxes
153
+ annotation_types_group = self.create_annotation_layout()
154
+ image_options_group = self.create_image_source_layout()
155
+ negative_samples_group = self.create_negative_samples_layout()
156
+
157
+ options_layout.addWidget(annotation_types_group)
158
+ options_layout.addWidget(image_options_group)
159
+ options_layout.addWidget(negative_samples_group)
160
+
161
+ self.layout.addLayout(options_layout)
162
+
163
+ def create_annotation_layout(self):
164
+ """Creates the annotation type checkboxes layout group box."""
152
165
  group_box = QGroupBox("Annotation Types")
153
- layout = QHBoxLayout()
166
+ layout = QVBoxLayout()
154
167
 
155
168
  self.include_patches_checkbox = QCheckBox("Include Patch Annotations")
156
169
  self.include_rectangles_checkbox = QCheckBox("Include Rectangle Annotations")
@@ -161,30 +174,24 @@ class Base(QDialog):
161
174
  layout.addWidget(self.include_polygons_checkbox)
162
175
 
163
176
  group_box.setLayout(layout)
164
- self.layout.addWidget(group_box)
177
+ return group_box
165
178
 
166
- def setup_options_layout(self):
167
- """Setup the image options layout."""
168
- group_box = QGroupBox("Image Options")
169
- layout = QHBoxLayout() # Changed from QVBoxLayout to QHBoxLayout
179
+ def create_image_source_layout(self):
180
+ """Creates the image source options layout group box."""
181
+ group_box = QGroupBox("Image Source")
182
+ layout = QVBoxLayout()
170
183
 
171
- # Create a button group for the image checkboxes
172
184
  self.image_options_group = QButtonGroup(self)
173
185
 
174
186
  self.all_images_radio = QRadioButton("All Images")
175
187
  self.filtered_images_radio = QRadioButton("Filtered Images")
176
188
 
177
- # Add the radio buttons to the button group
178
189
  self.image_options_group.addButton(self.all_images_radio)
179
190
  self.image_options_group.addButton(self.filtered_images_radio)
180
-
181
- # Ensure only one radio button can be checked at a time
182
191
  self.image_options_group.setExclusive(True)
183
192
 
184
- # Set the default radio button
185
193
  self.all_images_radio.setChecked(True)
186
194
 
187
- # Connect radio button signals
188
195
  self.all_images_radio.toggled.connect(self.update_image_selection)
189
196
  self.filtered_images_radio.toggled.connect(self.update_image_selection)
190
197
 
@@ -192,7 +199,32 @@ class Base(QDialog):
192
199
  layout.addWidget(self.filtered_images_radio)
193
200
 
194
201
  group_box.setLayout(layout)
195
- self.layout.addWidget(group_box)
202
+ return group_box
203
+
204
+ def create_negative_samples_layout(self):
205
+ """Creates the negative sample options layout group box."""
206
+ group_box = QGroupBox("Negative Samples")
207
+ layout = QVBoxLayout()
208
+
209
+ self.negative_samples_group = QButtonGroup(self)
210
+
211
+ self.include_negatives_radio = QRadioButton("Include Negatives")
212
+ self.exclude_negatives_radio = QRadioButton("Exclude Negatives")
213
+
214
+ self.negative_samples_group.addButton(self.include_negatives_radio)
215
+ self.negative_samples_group.addButton(self.exclude_negatives_radio)
216
+ self.negative_samples_group.setExclusive(True)
217
+
218
+ self.exclude_negatives_radio.setChecked(True)
219
+
220
+ # Connect to update stats when changed. Only one needed for the group.
221
+ self.include_negatives_radio.toggled.connect(self.update_summary_statistics)
222
+
223
+ layout.addWidget(self.include_negatives_radio)
224
+ layout.addWidget(self.exclude_negatives_radio)
225
+
226
+ group_box.setLayout(layout)
227
+ return group_box
196
228
 
197
229
  def setup_table_layout(self):
198
230
  """Setup the label counts table layout."""
@@ -424,6 +456,11 @@ class Base(QDialog):
424
456
  else:
425
457
  images = self.image_window.raster_manager.image_paths
426
458
 
459
+ # If "Exclude Negatives" is checked, only use images that have selected annotations.
460
+ if self.exclude_negatives_radio.isChecked():
461
+ image_paths_with_annotations = {a.image_path for a in self.selected_annotations}
462
+ images = [img for img in images if img in image_paths_with_annotations]
463
+
427
464
  random.shuffle(images)
428
465
 
429
466
  train_split = int(len(images) * self.train_ratio)
@@ -551,9 +588,6 @@ class Base(QDialog):
551
588
 
552
589
  self.updating_summary_statistics = True
553
590
 
554
- # Split the data by images
555
- self.split_data()
556
-
557
591
  # Selected labels based on user's selection
558
592
  self.selected_labels = []
559
593
  for row in range(self.label_counts_table.rowCount()):
@@ -564,6 +598,9 @@ class Base(QDialog):
564
598
 
565
599
  # Filter annotations based on the selected annotation types and current tab
566
600
  self.selected_annotations = self.filter_annotations()
601
+
602
+ # Split the data by images
603
+ self.split_data()
567
604
 
568
605
  # Split the data by annotations
569
606
  self.determine_splits()
@@ -704,4 +741,4 @@ class Base(QDialog):
704
741
  raise NotImplementedError("Method must be implemented in the subclass.")
705
742
 
706
743
  def process_annotations(self, annotations, split_dir, split):
707
- raise NotImplementedError("Method must be implemented in the subclass.")
744
+ raise NotImplementedError("Method must be implemented in the subclass.")
@@ -60,6 +60,10 @@ class Classify(Base):
60
60
  self.include_polygons_checkbox.setChecked(True)
61
61
  self.include_polygons_checkbox.setEnabled(True)
62
62
 
63
+ # Disable negative sample options for classification
64
+ self.include_negatives_radio.setEnabled(False)
65
+ self.exclude_negatives_radio.setEnabled(False)
66
+
63
67
  def create_dataset(self, output_dir_path):
64
68
  """
65
69
  Create an image classification dataset.
@@ -219,4 +223,4 @@ class Classify(Base):
219
223
  progress_bar.stop_progress()
220
224
  progress_bar.close()
221
225
  progress_bar = None
222
- gc.collect()
226
+ gc.collect()