coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.73__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
  2. coralnet_toolbox/Explorer/QtDataItem.py +1 -1
  3. coralnet_toolbox/Explorer/QtExplorer.py +143 -3
  4. coralnet_toolbox/Explorer/QtSettingsWidgets.py +46 -4
  5. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
  6. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
  7. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
  8. coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
  9. coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
  10. coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
  11. coralnet_toolbox/QtAnnotationWindow.py +42 -14
  12. coralnet_toolbox/QtEventFilter.py +8 -2
  13. coralnet_toolbox/QtImageWindow.py +17 -18
  14. coralnet_toolbox/QtLabelWindow.py +1 -1
  15. coralnet_toolbox/QtMainWindow.py +143 -8
  16. coralnet_toolbox/Rasters/QtRaster.py +59 -7
  17. coralnet_toolbox/Rasters/RasterTableModel.py +34 -6
  18. coralnet_toolbox/SAM/QtBatchInference.py +0 -2
  19. coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
  20. coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
  21. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1016 -0
  22. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +69 -53
  23. coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
  24. coralnet_toolbox/SeeAnything/__init__.py +2 -0
  25. coralnet_toolbox/Tools/QtSAMTool.py +150 -7
  26. coralnet_toolbox/Tools/QtSeeAnythingTool.py +220 -55
  27. coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
  28. coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
  29. coralnet_toolbox/__init__.py +1 -1
  30. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.73.dist-info}/METADATA +1 -1
  31. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.73.dist-info}/RECORD +35 -34
  32. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.73.dist-info}/WHEEL +0 -0
  33. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.73.dist-info}/entry_points.txt +0 -0
  34. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.73.dist-info}/licenses/LICENSE.txt +0 -0
  35. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.73.dist-info}/top_level.txt +0 -0
@@ -52,7 +52,7 @@ class DeployModelDialog(QDialog):
52
52
  self.annotation_window = main_window.annotation_window
53
53
 
54
54
  self.setWindowIcon(get_icon("coral.png"))
55
- self.setWindowTitle("AutoDistill Deploy Model (Ctrl + 5)")
55
+ self.setWindowTitle("AutoDistill Deploy Model (Ctrl + 6)")
56
56
  self.resize(400, 325)
57
57
 
58
58
  # Initialize variables
@@ -350,18 +350,29 @@ class DeployModelDialog(QDialog):
350
350
 
351
351
  def update_sam_task_state(self):
352
352
  """
353
- Centralized method to check if SAM is loaded and update task and dropdown accordingly.
353
+ Centralized method to check if SAM is loaded and update task accordingly.
354
+ If the user has selected to use SAM, this function ensures the task is set to 'segment'.
355
+ Crucially, it does NOT alter the task if SAM is not selected, respecting the
356
+ user's choice from the 'Task' dropdown.
354
357
  """
355
- sam_active = (
356
- self.sam_dialog is not None and
357
- self.sam_dialog.loaded_model is not None and
358
- self.use_sam_dropdown.currentText() == "True"
359
- )
360
- if sam_active:
361
- self.task = 'segment'
362
- else:
363
- self.task = 'detect'
364
- self.use_sam_dropdown.setCurrentText("False")
358
+ # Check if the user wants to use the SAM model
359
+ if self.use_sam_dropdown.currentText() == "True":
360
+ # SAM is requested. Check if it's actually available.
361
+ sam_is_available = (
362
+ hasattr(self, 'sam_dialog') and
363
+ self.sam_dialog is not None and
364
+ self.sam_dialog.loaded_model is not None
365
+ )
366
+
367
+ if sam_is_available:
368
+ # If SAM is wanted and available, the task must be segmentation.
369
+ self.task = 'segment'
370
+ else:
371
+ # If SAM is wanted but not available, revert the dropdown and do nothing else.
372
+ # The 'is_sam_model_deployed' function already handles showing an error message.
373
+ self.use_sam_dropdown.setCurrentText("False")
374
+
375
+ # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
365
376
 
366
377
  def load_model(self):
367
378
  """
@@ -272,7 +272,7 @@ class AnnotationDataItem:
272
272
 
273
273
  self.embedding_x = embedding_x if embedding_x is not None else 0.0
274
274
  self.embedding_y = embedding_y if embedding_y is not None else 0.0
275
- self.embedding_id = embedding_id if embedding_id is not None else 0
275
+ self.embedding_id = embedding_id
276
276
 
277
277
  self._is_selected = False
278
278
  self._preview_label = None
@@ -16,7 +16,7 @@ from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollAre
16
16
  QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget,
17
17
  QMainWindow, QSplitter, QGroupBox, QSlider, QMessageBox,
18
18
  QApplication, QGraphicsRectItem, QRubberBand, QMenu,
19
- QWidgetAction, QToolButton, QAction)
19
+ QWidgetAction, QToolButton, QAction, QDoubleSpinBox)
20
20
 
21
21
  from coralnet_toolbox.Explorer.QtFeatureStore import FeatureStore
22
22
  from coralnet_toolbox.Explorer.QtDataItem import AnnotationDataItem
@@ -28,6 +28,7 @@ from coralnet_toolbox.Explorer.QtSettingsWidgets import UncertaintySettingsWidge
28
28
  from coralnet_toolbox.Explorer.QtSettingsWidgets import MislabelSettingsWidget
29
29
  from coralnet_toolbox.Explorer.QtSettingsWidgets import EmbeddingSettingsWidget
30
30
  from coralnet_toolbox.Explorer.QtSettingsWidgets import AnnotationSettingsWidget
31
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import DuplicateSettingsWidget
31
32
 
32
33
  from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
33
34
 
@@ -68,6 +69,8 @@ class EmbeddingViewer(QWidget):
68
69
  mislabel_parameters_changed = pyqtSignal(dict)
69
70
  find_uncertain_requested = pyqtSignal()
70
71
  uncertainty_parameters_changed = pyqtSignal(dict)
72
+ find_duplicates_requested = pyqtSignal()
73
+ duplicate_parameters_changed = pyqtSignal(dict)
71
74
 
72
75
  def __init__(self, parent=None):
73
76
  """Initialize the EmbeddingViewer widget."""
@@ -193,8 +196,37 @@ class EmbeddingViewer(QWidget):
193
196
 
194
197
  uncertainty_settings_widget.parameters_changed.connect(self.uncertainty_parameters_changed.emit)
195
198
  toolbar_layout.addWidget(self.find_uncertain_button)
199
+
200
+ # Create a QToolButton for duplicate detection
201
+ self.find_duplicates_button = QToolButton()
202
+ self.find_duplicates_button.setText("Find Duplicates")
203
+ self.find_duplicates_button.setToolTip(
204
+ "Find annotations that are likely duplicates based on feature similarity."
205
+ )
206
+ self.find_duplicates_button.setPopupMode(QToolButton.MenuButtonPopup)
207
+ self.find_duplicates_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
208
+ self.find_duplicates_button.setStyleSheet(
209
+ "QToolButton::menu-indicator { "
210
+ "subcontrol-position: right center; "
211
+ "subcontrol-origin: padding; "
212
+ "left: -4px; }"
213
+ )
214
+
215
+ run_duplicates_action = QAction("Find Duplicates", self)
216
+ run_duplicates_action.triggered.connect(self.find_duplicates_requested.emit)
217
+ self.find_duplicates_button.setDefaultAction(run_duplicates_action)
218
+
219
+ duplicate_settings_widget = DuplicateSettingsWidget()
220
+ duplicate_menu = QMenu(self)
221
+ duplicate_widget_action = QWidgetAction(duplicate_menu)
222
+ duplicate_widget_action.setDefaultWidget(duplicate_settings_widget)
223
+ duplicate_menu.addAction(duplicate_widget_action)
224
+ self.find_duplicates_button.setMenu(duplicate_menu)
225
+
226
+ duplicate_settings_widget.parameters_changed.connect(self.duplicate_parameters_changed.emit)
227
+ toolbar_layout.addWidget(self.find_duplicates_button)
196
228
 
197
- # Add a strech and separator
229
+ # Add a stretch and separator
198
230
  toolbar_layout.addStretch()
199
231
  toolbar_layout.addWidget(self._create_separator())
200
232
 
@@ -293,6 +325,7 @@ class EmbeddingViewer(QWidget):
293
325
 
294
326
  self.find_mislabels_button.setEnabled(points_exist)
295
327
  self.find_uncertain_button.setEnabled(points_exist and self.is_uncertainty_analysis_available)
328
+ self.find_duplicates_button.setEnabled(points_exist)
296
329
  self.center_on_selection_button.setEnabled(points_exist and selection_exists)
297
330
 
298
331
  if self.isolated_mode:
@@ -348,6 +381,7 @@ class EmbeddingViewer(QWidget):
348
381
  self.center_on_selection_button.setEnabled(False) # Disable center button
349
382
  self.find_mislabels_button.setEnabled(False)
350
383
  self.find_uncertain_button.setEnabled(False)
384
+ self.find_duplicates_button.setEnabled(False)
351
385
 
352
386
  self.isolate_button.show()
353
387
  self.isolate_button.setEnabled(False)
@@ -851,7 +885,7 @@ class AnnotationViewer(QWidget):
851
885
 
852
886
  # Show resize handles for Rectangle annotations
853
887
  if isinstance(annotation_to_select, RectangleAnnotation):
854
- explorer.annotation_window.set_selected_tool('select') # Accidently unselects in AnnotationWindow
888
+ explorer.annotation_window.set_selected_tool('select') # Accidentally unselects in AnnotationWindow
855
889
  explorer.annotation_window.select_annotation(annotation_to_select, quiet_mode=True)
856
890
  select_tool = explorer.annotation_window.tools.get('select')
857
891
 
@@ -1584,6 +1618,7 @@ class ExplorerWindow(QMainWindow):
1584
1618
  self.mislabel_params = {'k': 20, 'threshold': 0.6}
1585
1619
  self.uncertainty_params = {'confidence': 0.6, 'margin': 0.1}
1586
1620
  self.similarity_params = {'k': 30}
1621
+ self.duplicate_params = {'threshold': 0.05}
1587
1622
 
1588
1623
  self.data_item_cache = {} # Cache for AnnotationDataItem objects
1589
1624
 
@@ -1744,6 +1779,8 @@ class ExplorerWindow(QMainWindow):
1744
1779
  self.model_settings_widget.selection_changed.connect(self.on_model_selection_changed)
1745
1780
  self.embedding_viewer.find_uncertain_requested.connect(self.find_uncertain_annotations)
1746
1781
  self.embedding_viewer.uncertainty_parameters_changed.connect(self.on_uncertainty_params_changed)
1782
+ self.embedding_viewer.find_duplicates_requested.connect(self.find_duplicate_annotations)
1783
+ self.embedding_viewer.duplicate_parameters_changed.connect(self.on_duplicate_params_changed)
1747
1784
  self.annotation_viewer.find_similar_requested.connect(self.find_similar_annotations)
1748
1785
  self.annotation_viewer.similarity_settings_widget.parameters_changed.connect(self.on_similarity_params_changed)
1749
1786
 
@@ -1887,6 +1924,12 @@ class ExplorerWindow(QMainWindow):
1887
1924
  """Updates the stored parameters for uncertainty analysis."""
1888
1925
  self.uncertainty_params = params
1889
1926
  print(f"Uncertainty parameters updated: {self.uncertainty_params}")
1927
+
1928
+ @pyqtSlot(dict)
1929
+ def on_duplicate_params_changed(self, params):
1930
+ """Updates the stored parameters for duplicate detection."""
1931
+ self.duplicate_params = params
1932
+ print(f"Duplicate detection parameters updated: {self.duplicate_params}")
1890
1933
 
1891
1934
  @pyqtSlot(dict)
1892
1935
  def on_similarity_params_changed(self, params):
@@ -2067,6 +2110,98 @@ class ExplorerWindow(QMainWindow):
2067
2110
 
2068
2111
  finally:
2069
2112
  QApplication.restoreOverrideCursor()
2113
+
2114
+ def find_duplicate_annotations(self):
2115
+ """
2116
+ Identifies annotations that are likely duplicates based on feature similarity.
2117
+ It uses a nearest-neighbor approach in the high-dimensional feature space.
2118
+ For each group of duplicates found, it selects all but one "original".
2119
+ """
2120
+ threshold = self.duplicate_params.get('threshold', 0.05)
2121
+
2122
+ if not self.embedding_viewer.points_by_id or len(self.embedding_viewer.points_by_id) < 2:
2123
+ QMessageBox.information(self,
2124
+ "Not Enough Data",
2125
+ "This feature requires at least 2 points in the embedding viewer.")
2126
+ return
2127
+
2128
+ items_in_view = list(self.embedding_viewer.points_by_id.values())
2129
+ data_items_in_view = [p.data_item for p in items_in_view]
2130
+
2131
+ model_info = self.model_settings_widget.get_selected_model()
2132
+ model_name, feature_mode = model_info if isinstance(model_info, tuple) else (model_info, "default")
2133
+ sanitized_model_name = os.path.basename(model_name).replace(' ', '_')
2134
+ sanitized_feature_mode = feature_mode.replace(' ', '_').replace('/', '_')
2135
+ model_key = f"{sanitized_model_name}_{sanitized_feature_mode}"
2136
+
2137
+ # Make cursor busy
2138
+ QApplication.setOverrideCursor(Qt.WaitCursor)
2139
+ try:
2140
+ index = self.feature_store._get_or_load_index(model_key)
2141
+ if index is None:
2142
+ QMessageBox.warning(self, "Error", "Could not find a valid feature index for the current model.")
2143
+ return
2144
+
2145
+ features_dict, _ = self.feature_store.get_features(data_items_in_view, model_key)
2146
+ if not features_dict:
2147
+ QMessageBox.warning(self, "Error", "Could not retrieve features for the items in view.")
2148
+ return
2149
+
2150
+ query_ann_ids = list(features_dict.keys())
2151
+ query_vectors = np.array([features_dict[ann_id] for ann_id in query_ann_ids]).astype('float32')
2152
+
2153
+ # Find the 2 nearest neighbors for each vector. D = squared L2 distances.
2154
+ D, I = index.search(query_vectors, 2)
2155
+
2156
+ # Use a Disjoint Set Union (DSU) data structure to group duplicates.
2157
+ parent = {ann_id: ann_id for ann_id in query_ann_ids}
2158
+
2159
+ # Helper functions for DSU
2160
+ def find_set(v):
2161
+ if v == parent[v]:
2162
+ return v
2163
+ parent[v] = find_set(parent[v])
2164
+ return parent[v]
2165
+
2166
+ def unite_sets(a, b):
2167
+ a = find_set(a)
2168
+ b = find_set(b)
2169
+ if a != b:
2170
+ parent[b] = a
2171
+
2172
+ id_map = self.feature_store.get_faiss_index_to_annotation_id_map(model_key)
2173
+
2174
+ for i, ann_id in enumerate(query_ann_ids):
2175
+ neighbor_faiss_idx = I[i, 1] # The second result is the nearest neighbor
2176
+ distance = D[i, 1]
2177
+
2178
+ if distance < threshold:
2179
+ neighbor_ann_id = id_map.get(neighbor_faiss_idx)
2180
+ if neighbor_ann_id and neighbor_ann_id in parent:
2181
+ unite_sets(ann_id, neighbor_ann_id)
2182
+
2183
+ # Group annotations by their set representative
2184
+ groups = {}
2185
+ for ann_id in query_ann_ids:
2186
+ root = find_set(ann_id)
2187
+ if root not in groups:
2188
+ groups[root] = []
2189
+ groups[root].append(ann_id)
2190
+
2191
+ copies_to_select = set()
2192
+ for root_id, group_ids in groups.items():
2193
+ if len(group_ids) > 1:
2194
+ # Sort IDs to consistently pick the same "original".
2195
+ # Sorting strings is reliable.
2196
+ sorted_ids = sorted(group_ids)
2197
+ # The first ID is the original, add the rest to the selection.
2198
+ copies_to_select.update(sorted_ids[1:])
2199
+
2200
+ print(f"Found {len(copies_to_select)} duplicate annotations.")
2201
+ self.embedding_viewer.render_selection_from_ids(copies_to_select)
2202
+
2203
+ finally:
2204
+ QApplication.restoreOverrideCursor()
2070
2205
 
2071
2206
  def find_uncertain_annotations(self):
2072
2207
  """
@@ -2658,6 +2793,7 @@ class ExplorerWindow(QMainWindow):
2658
2793
  norm_y = (embedded_features[i, 1] - min_vals[1]) / range_vals[1] if range_vals[1] > 0 else 0.5
2659
2794
  item.embedding_x = (norm_x * scale_factor) - (scale_factor / 2)
2660
2795
  item.embedding_y = (norm_y * scale_factor) - (scale_factor / 2)
2796
+ item.embedding_id = i
2661
2797
 
2662
2798
  def run_embedding_pipeline(self):
2663
2799
  """
@@ -2813,6 +2949,10 @@ class ExplorerWindow(QMainWindow):
2813
2949
  self.current_data_items = [
2814
2950
  item for item in self.current_data_items if item.annotation.id not in deleted_ann_ids
2815
2951
  ]
2952
+ # Also update the annotation viewer's list to keep it in sync
2953
+ self.annotation_viewer.all_data_items = [
2954
+ item for item in self.annotation_viewer.all_data_items if item.annotation.id not in deleted_ann_ids
2955
+ ]
2816
2956
  for ann_id in deleted_ann_ids:
2817
2957
  if ann_id in self.data_item_cache:
2818
2958
  del self.data_item_cache[ann_id]
@@ -4,7 +4,7 @@ import warnings
4
4
  from PyQt5.QtCore import Qt, pyqtSignal, pyqtSlot
5
5
  from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QPushButton, QComboBox, QLabel,
6
6
  QWidget, QGroupBox, QSlider, QListWidget, QTabWidget,
7
- QLineEdit, QFileDialog, QFormLayout, QSpinBox)
7
+ QLineEdit, QFileDialog, QFormLayout, QSpinBox, QDoubleSpinBox)
8
8
 
9
9
  from coralnet_toolbox.MachineLearning.Community.cfg import get_available_configs
10
10
 
@@ -189,6 +189,48 @@ class SimilaritySettingsWidget(QWidget):
189
189
  'k': self.k_spinbox.value()
190
190
  }
191
191
 
192
+
193
+ class DuplicateSettingsWidget(QWidget):
194
+ """Widget for configuring duplicate detection parameters."""
195
+ parameters_changed = pyqtSignal(dict)
196
+
197
+ def __init__(self, parent=None):
198
+ super(DuplicateSettingsWidget, self).__init__(parent)
199
+ layout = QVBoxLayout(self)
200
+ layout.setContentsMargins(10, 10, 10, 10)
201
+
202
+ # Using a DoubleSpinBox for the distance threshold
203
+ self.threshold_spinbox = QDoubleSpinBox()
204
+ self.threshold_spinbox.setDecimals(3)
205
+ self.threshold_spinbox.setRange(0.0, 10.0)
206
+ self.threshold_spinbox.setSingleStep(0.01)
207
+ self.threshold_spinbox.setValue(0.1) # Default value for squared L2 distance
208
+ self.threshold_spinbox.setToolTip(
209
+ "Similarity Threshold (Squared L2 Distance).\n"
210
+ "Lower values mean more similar.\n"
211
+ "A value of 0 means identical features."
212
+ )
213
+
214
+ self.threshold_spinbox.valueChanged.connect(self._emit_parameters)
215
+
216
+ form_layout = QHBoxLayout()
217
+ form_layout.addWidget(QLabel("Threshold:"))
218
+ form_layout.addWidget(self.threshold_spinbox)
219
+ layout.addLayout(form_layout)
220
+
221
+ def _emit_parameters(self):
222
+ """Emits the current parameters."""
223
+ params = {
224
+ 'threshold': self.threshold_spinbox.value()
225
+ }
226
+ self.parameters_changed.emit(params)
227
+
228
+ def get_parameters(self):
229
+ """Returns the current parameters as a dictionary."""
230
+ return {
231
+ 'threshold': self.threshold_spinbox.value()
232
+ }
233
+
192
234
 
193
235
  class AnnotationSettingsWidget(QGroupBox):
194
236
  """Widget for filtering annotations by image, type, and label in a multi-column layout."""
@@ -213,7 +255,7 @@ class AnnotationSettingsWidget(QGroupBox):
213
255
  images_column.addWidget(images_label)
214
256
 
215
257
  self.images_list = QListWidget()
216
- self.images_list.setSelectionMode(QListWidget.MultiSelection)
258
+ self.images_list.setSelectionMode(QListWidget.ExtendedSelection)
217
259
  self.images_list.setMaximumHeight(50)
218
260
 
219
261
  if hasattr(self.main_window, 'image_window') and hasattr(self.main_window.image_window, 'raster_manager'):
@@ -241,7 +283,7 @@ class AnnotationSettingsWidget(QGroupBox):
241
283
  type_column.addWidget(type_label)
242
284
 
243
285
  self.annotation_type_list = QListWidget()
244
- self.annotation_type_list.setSelectionMode(QListWidget.MultiSelection)
286
+ self.annotation_type_list.setSelectionMode(QListWidget.ExtendedSelection)
245
287
  self.annotation_type_list.setMaximumHeight(50)
246
288
  self.annotation_type_list.addItems(["PatchAnnotation",
247
289
  "RectangleAnnotation",
@@ -269,7 +311,7 @@ class AnnotationSettingsWidget(QGroupBox):
269
311
  label_column.addWidget(label_label)
270
312
 
271
313
  self.label_list = QListWidget()
272
- self.label_list.setSelectionMode(QListWidget.MultiSelection)
314
+ self.label_list.setSelectionMode(QListWidget.ExtendedSelection)
273
315
  self.label_list.setMaximumHeight(50)
274
316
 
275
317
  if hasattr(self.main_window, 'label_window') and hasattr(self.main_window.label_window, 'labels'):
@@ -123,18 +123,29 @@ class Detect(Base):
123
123
 
124
124
  def update_sam_task_state(self):
125
125
  """
126
- Centralized method to check if SAM is loaded and update task and dropdown accordingly.
126
+ Centralized method to check if SAM is loaded and update task accordingly.
127
+ If the user has selected to use SAM, this function ensures the task is set to 'segment'.
128
+ Crucially, it does NOT alter the task if SAM is not selected, respecting the
129
+ user's choice from the 'Task' dropdown.
127
130
  """
128
- sam_active = (
129
- self.sam_dialog is not None and
130
- self.sam_dialog.loaded_model is not None and
131
- self.use_sam_dropdown.currentText() == "True"
132
- )
133
- if sam_active:
134
- self.task = 'segment'
135
- else:
136
- self.task = 'detect'
137
- self.use_sam_dropdown.setCurrentText("False")
131
+ # Check if the user wants to use the SAM model
132
+ if self.use_sam_dropdown.currentText() == "True":
133
+ # SAM is requested. Check if it's actually available.
134
+ sam_is_available = (
135
+ hasattr(self, 'sam_dialog') and
136
+ self.sam_dialog is not None and
137
+ self.sam_dialog.loaded_model is not None
138
+ )
139
+
140
+ if sam_is_available:
141
+ # If SAM is wanted and available, the task must be segmentation.
142
+ self.task = 'segment'
143
+ else:
144
+ # If SAM is wanted but not available, revert the dropdown and do nothing else.
145
+ # The 'is_sam_model_deployed' function already handles showing an error message.
146
+ self.use_sam_dropdown.setCurrentText("False")
147
+
148
+ # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
138
149
 
139
150
  def load_model(self):
140
151
  """
@@ -123,17 +123,29 @@ class Segment(Base):
123
123
 
124
124
  def update_sam_task_state(self):
125
125
  """
126
- Centralized method to check if SAM is loaded and update task and dropdown accordingly.
126
+ Centralized method to check if SAM is loaded and update task accordingly.
127
+ If the user has selected to use SAM, this function ensures the task is set to 'segment'.
128
+ Crucially, it does NOT alter the task if SAM is not selected, respecting the
129
+ user's choice from the 'Task' dropdown.
127
130
  """
128
- sam_active = (
129
- self.sam_dialog is not None and
130
- self.sam_dialog.loaded_model is not None and
131
- self.use_sam_dropdown.currentText() == "True"
132
- )
133
- if sam_active:
134
- self.task = 'segment'
135
- else:
136
- self.use_sam_dropdown.setCurrentText("False")
131
+ # Check if the user wants to use the SAM model
132
+ if self.use_sam_dropdown.currentText() == "True":
133
+ # SAM is requested. Check if it's actually available.
134
+ sam_is_available = (
135
+ hasattr(self, 'sam_dialog') and
136
+ self.sam_dialog is not None and
137
+ self.sam_dialog.loaded_model is not None
138
+ )
139
+
140
+ if sam_is_available:
141
+ # If SAM is wanted and available, the task must be segmentation.
142
+ self.task = 'segment'
143
+ else:
144
+ # If SAM is wanted but not available, revert the dropdown and do nothing else.
145
+ # The 'is_sam_model_deployed' function already handles showing an error message.
146
+ self.use_sam_dropdown.setCurrentText("False")
147
+
148
+ # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
137
149
 
138
150
  def load_model(self):
139
151
  """
@@ -42,7 +42,7 @@ class Base(QDialog):
42
42
  self.annotation_window = main_window.annotation_window
43
43
  self.image_window = main_window.image_window
44
44
 
45
- self.resize(1000, 600)
45
+ self.resize(1000, 800)
46
46
  self.setWindowIcon(get_icon("coral.png"))
47
47
  self.setWindowTitle("Export Dataset")
48
48
 
@@ -64,10 +64,8 @@ class Base(QDialog):
64
64
  self.setup_output_layout()
65
65
  # Setup the ratio layout
66
66
  self.setup_ratio_layout()
67
- # Setup the annotation layout
68
- self.setup_annotation_layout()
69
- # Setup the options layout
70
- self.setup_options_layout()
67
+ # Setup the data selection layout
68
+ self.setup_data_selection_layout()
71
69
  # Setup the table layout
72
70
  self.setup_table_layout()
73
71
  # Setup the status layout
@@ -147,10 +145,25 @@ class Base(QDialog):
147
145
  group_box.setLayout(layout)
148
146
  self.layout.addWidget(group_box)
149
147
 
150
- def setup_annotation_layout(self):
151
- """Setup the annotation type checkboxes layout."""
148
+ def setup_data_selection_layout(self):
149
+ """Setup the layout for data selection options in a horizontal arrangement."""
150
+ options_layout = QHBoxLayout()
151
+
152
+ # Create and add the group boxes
153
+ annotation_types_group = self.create_annotation_layout()
154
+ image_options_group = self.create_image_source_layout()
155
+ negative_samples_group = self.create_negative_samples_layout()
156
+
157
+ options_layout.addWidget(annotation_types_group)
158
+ options_layout.addWidget(image_options_group)
159
+ options_layout.addWidget(negative_samples_group)
160
+
161
+ self.layout.addLayout(options_layout)
162
+
163
+ def create_annotation_layout(self):
164
+ """Creates the annotation type checkboxes layout group box."""
152
165
  group_box = QGroupBox("Annotation Types")
153
- layout = QHBoxLayout()
166
+ layout = QVBoxLayout()
154
167
 
155
168
  self.include_patches_checkbox = QCheckBox("Include Patch Annotations")
156
169
  self.include_rectangles_checkbox = QCheckBox("Include Rectangle Annotations")
@@ -161,30 +174,24 @@ class Base(QDialog):
161
174
  layout.addWidget(self.include_polygons_checkbox)
162
175
 
163
176
  group_box.setLayout(layout)
164
- self.layout.addWidget(group_box)
177
+ return group_box
165
178
 
166
- def setup_options_layout(self):
167
- """Setup the image options layout."""
168
- group_box = QGroupBox("Image Options")
169
- layout = QHBoxLayout() # Changed from QVBoxLayout to QHBoxLayout
179
+ def create_image_source_layout(self):
180
+ """Creates the image source options layout group box."""
181
+ group_box = QGroupBox("Image Source")
182
+ layout = QVBoxLayout()
170
183
 
171
- # Create a button group for the image checkboxes
172
184
  self.image_options_group = QButtonGroup(self)
173
185
 
174
186
  self.all_images_radio = QRadioButton("All Images")
175
187
  self.filtered_images_radio = QRadioButton("Filtered Images")
176
188
 
177
- # Add the radio buttons to the button group
178
189
  self.image_options_group.addButton(self.all_images_radio)
179
190
  self.image_options_group.addButton(self.filtered_images_radio)
180
-
181
- # Ensure only one radio button can be checked at a time
182
191
  self.image_options_group.setExclusive(True)
183
192
 
184
- # Set the default radio button
185
193
  self.all_images_radio.setChecked(True)
186
194
 
187
- # Connect radio button signals
188
195
  self.all_images_radio.toggled.connect(self.update_image_selection)
189
196
  self.filtered_images_radio.toggled.connect(self.update_image_selection)
190
197
 
@@ -192,7 +199,32 @@ class Base(QDialog):
192
199
  layout.addWidget(self.filtered_images_radio)
193
200
 
194
201
  group_box.setLayout(layout)
195
- self.layout.addWidget(group_box)
202
+ return group_box
203
+
204
+ def create_negative_samples_layout(self):
205
+ """Creates the negative sample options layout group box."""
206
+ group_box = QGroupBox("Negative Samples")
207
+ layout = QVBoxLayout()
208
+
209
+ self.negative_samples_group = QButtonGroup(self)
210
+
211
+ self.include_negatives_radio = QRadioButton("Include Negatives")
212
+ self.exclude_negatives_radio = QRadioButton("Exclude Negatives")
213
+
214
+ self.negative_samples_group.addButton(self.include_negatives_radio)
215
+ self.negative_samples_group.addButton(self.exclude_negatives_radio)
216
+ self.negative_samples_group.setExclusive(True)
217
+
218
+ self.exclude_negatives_radio.setChecked(True)
219
+
220
+ # Connect to update stats when changed. Only one needed for the group.
221
+ self.include_negatives_radio.toggled.connect(self.update_summary_statistics)
222
+
223
+ layout.addWidget(self.include_negatives_radio)
224
+ layout.addWidget(self.exclude_negatives_radio)
225
+
226
+ group_box.setLayout(layout)
227
+ return group_box
196
228
 
197
229
  def setup_table_layout(self):
198
230
  """Setup the label counts table layout."""
@@ -424,6 +456,11 @@ class Base(QDialog):
424
456
  else:
425
457
  images = self.image_window.raster_manager.image_paths
426
458
 
459
+ # If "Exclude Negatives" is checked, only use images that have selected annotations.
460
+ if self.exclude_negatives_radio.isChecked():
461
+ image_paths_with_annotations = {a.image_path for a in self.selected_annotations}
462
+ images = [img for img in images if img in image_paths_with_annotations]
463
+
427
464
  random.shuffle(images)
428
465
 
429
466
  train_split = int(len(images) * self.train_ratio)
@@ -551,9 +588,6 @@ class Base(QDialog):
551
588
 
552
589
  self.updating_summary_statistics = True
553
590
 
554
- # Split the data by images
555
- self.split_data()
556
-
557
591
  # Selected labels based on user's selection
558
592
  self.selected_labels = []
559
593
  for row in range(self.label_counts_table.rowCount()):
@@ -564,6 +598,9 @@ class Base(QDialog):
564
598
 
565
599
  # Filter annotations based on the selected annotation types and current tab
566
600
  self.selected_annotations = self.filter_annotations()
601
+
602
+ # Split the data by images
603
+ self.split_data()
567
604
 
568
605
  # Split the data by annotations
569
606
  self.determine_splits()
@@ -704,4 +741,4 @@ class Base(QDialog):
704
741
  raise NotImplementedError("Method must be implemented in the subclass.")
705
742
 
706
743
  def process_annotations(self, annotations, split_dir, split):
707
- raise NotImplementedError("Method must be implemented in the subclass.")
744
+ raise NotImplementedError("Method must be implemented in the subclass.")
@@ -60,6 +60,10 @@ class Classify(Base):
60
60
  self.include_polygons_checkbox.setChecked(True)
61
61
  self.include_polygons_checkbox.setEnabled(True)
62
62
 
63
+ # Disable negative sample options for classification
64
+ self.include_negatives_radio.setEnabled(False)
65
+ self.exclude_negatives_radio.setEnabled(False)
66
+
63
67
  def create_dataset(self, output_dir_path):
64
68
  """
65
69
  Create an image classification dataset.
@@ -219,4 +223,4 @@ class Classify(Base):
219
223
  progress_bar.stop_progress()
220
224
  progress_bar.close()
221
225
  progress_bar = None
222
- gc.collect()
226
+ gc.collect()