coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
- coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
- coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
- coralnet_toolbox/CoralNet/QtDownload.py +2 -1
- coralnet_toolbox/Explorer/QtDataItem.py +1 -1
- coralnet_toolbox/Explorer/QtExplorer.py +159 -17
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
- coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
- coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
- coralnet_toolbox/IO/QtOpenProject.py +46 -78
- coralnet_toolbox/IO/QtSaveProject.py +18 -43
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
- coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
- coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
- coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
- coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
- coralnet_toolbox/QtAnnotationWindow.py +42 -14
- coralnet_toolbox/QtEventFilter.py +19 -2
- coralnet_toolbox/QtImageWindow.py +134 -86
- coralnet_toolbox/QtLabelWindow.py +14 -2
- coralnet_toolbox/QtMainWindow.py +122 -9
- coralnet_toolbox/QtProgressBar.py +52 -27
- coralnet_toolbox/Rasters/QtRaster.py +59 -7
- coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
- coralnet_toolbox/SAM/QtBatchInference.py +0 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
- coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
- coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
- coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
- coralnet_toolbox/SeeAnything/__init__.py +2 -0
- coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
- coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
- coralnet_toolbox/Tools/QtSAMTool.py +222 -57
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
- coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
- coralnet_toolbox/Tools/QtSelectTool.py +27 -3
- coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
- coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
- coralnet_toolbox/Tools/__init__.py +2 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +137 -47
- coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
- coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -197,30 +197,50 @@ class ExportTagLabAnnotations:
|
|
197
197
|
"""
|
198
198
|
# Calculate bounding box, centroid, area, perimeter, and contour
|
199
199
|
points = annotation.points
|
200
|
+
# Convert points to TagLab contour format
|
201
|
+
contour = self.taglabToPoints(np.array([[point.x(), point.y()] for point in points]))
|
202
|
+
inner_contours = []
|
203
|
+
if hasattr(annotation, 'holes') and annotation.holes:
|
204
|
+
# Convert holes to TagLab format
|
205
|
+
for hole in annotation.holes:
|
206
|
+
inner_contours.append(self.taglabToPoints(np.array([[point.x(), point.y()] for point in hole])))
|
207
|
+
|
208
|
+
# Calculate bounding box
|
200
209
|
min_x = int(min(point.x() for point in points))
|
201
210
|
min_y = int(min(point.y() for point in points))
|
202
211
|
max_x = int(max(point.x() for point in points))
|
203
|
-
max_y = int(max(point.y() for point in points))
|
212
|
+
max_y = int(max(point.y() for point in points))
|
204
213
|
width = max_x - min_x
|
205
214
|
height = max_y - min_y
|
215
|
+
bbox = [min_x, min_y, width, height]
|
216
|
+
|
217
|
+
# Calculate centroid
|
206
218
|
centroid_x = float(f"{sum(point.x() for point in points) / len(points):.1f}")
|
207
219
|
centroid_y = float(f"{sum(point.y() for point in points) / len(points):.1f}")
|
220
|
+
centroid = [centroid_x, centroid_y]
|
221
|
+
|
208
222
|
area = float(f"{annotation.get_area():.1f}")
|
209
223
|
perimeter = float(f"{annotation.get_perimeter():.1f}")
|
210
|
-
|
224
|
+
data = annotation.data if hasattr(annotation, 'data') else {}
|
225
|
+
|
226
|
+
# Pop these keys from data if they exist
|
227
|
+
class_name = data.pop('class_name', annotation.label.short_label_code)
|
228
|
+
instance_name = data.pop('instance_name', "coral0")
|
229
|
+
blob_name = data.pop('blob_name', f"c-0-{centroid_x}x-{centroid_y}y")
|
230
|
+
note = data.pop('note', "")
|
211
231
|
|
212
232
|
annotation_dict = {
|
213
|
-
"bbox":
|
214
|
-
"centroid":
|
233
|
+
"bbox": bbox,
|
234
|
+
"centroid": centroid,
|
215
235
|
"area": area,
|
216
236
|
"perimeter": perimeter,
|
217
237
|
"contour": contour,
|
218
|
-
"inner contours":
|
219
|
-
"class name":
|
220
|
-
"instance name":
|
221
|
-
"blob name":
|
222
|
-
"note":
|
223
|
-
"data":
|
238
|
+
"inner contours": inner_contours,
|
239
|
+
"class name": class_name,
|
240
|
+
"instance name": instance_name,
|
241
|
+
"blob name": blob_name,
|
242
|
+
"note": note,
|
243
|
+
"data": data
|
224
244
|
}
|
225
245
|
|
226
246
|
return annotation_dict
|
@@ -161,23 +161,25 @@ class ImportTagLabAnnotations:
|
|
161
161
|
short_label_code = label_info['name'].strip()
|
162
162
|
long_label_code = label_info['name'].strip()
|
163
163
|
color = QColor(*label_info['fill'])
|
164
|
-
|
165
|
-
#
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
164
|
+
|
165
|
+
# Pack all other data into a dict
|
166
|
+
imported_data = {
|
167
|
+
'bbox': annotation.get('bbox'),
|
168
|
+
'centroid': annotation.get('centroid'),
|
169
|
+
'area': annotation.get('area'),
|
170
|
+
'perimeter': annotation.get('perimeter'),
|
171
|
+
'class_name': annotation.get('class name'),
|
172
|
+
'instance_name': annotation.get('instance name'),
|
173
|
+
'blob_name': annotation.get('blob name'),
|
174
|
+
'id': annotation.get('id'),
|
175
|
+
'note': annotation.get('note'),
|
176
|
+
'data': annotation.get('data'),
|
177
|
+
}
|
178
178
|
|
179
179
|
# Convert contour string to points
|
180
180
|
points = self.parse_contour(annotation['contour'])
|
181
|
+
# Convert inner contours to a list of lists of points (holes)
|
182
|
+
holes = [self.parse_contour(inner) for inner in annotation.get('inner contours', [])]
|
181
183
|
|
182
184
|
# Create the label if it doesn't exist
|
183
185
|
label = self.label_window.add_label_if_not_exists(short_label_code,
|
@@ -191,8 +193,12 @@ class ImportTagLabAnnotations:
|
|
191
193
|
long_label_code=long_label_code,
|
192
194
|
color=color,
|
193
195
|
image_path=image_full_path,
|
194
|
-
label_id=label_id
|
196
|
+
label_id=label_id,
|
197
|
+
holes=holes,
|
195
198
|
)
|
199
|
+
# Add additional data to the annotation
|
200
|
+
polygon_annotation.data = imported_data
|
201
|
+
|
196
202
|
# Add annotation to the dict
|
197
203
|
self.annotation_window.add_annotation_to_dict(polygon_annotation)
|
198
204
|
|
@@ -12,16 +12,12 @@ from PyQt5.QtWidgets import (QDialog, QFileDialog, QVBoxLayout, QPushButton, QLa
|
|
12
12
|
QLineEdit)
|
13
13
|
|
14
14
|
from coralnet_toolbox.QtLabelWindow import Label
|
15
|
-
|
16
15
|
from coralnet_toolbox.QtWorkArea import WorkArea
|
17
|
-
|
18
16
|
from coralnet_toolbox.Annotations.QtPatchAnnotation import PatchAnnotation
|
19
17
|
from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
|
20
18
|
from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
|
21
19
|
from coralnet_toolbox.Annotations.QtMultiPolygonAnnotation import MultiPolygonAnnotation
|
22
|
-
|
23
20
|
from coralnet_toolbox.Common.QtUpdateImagePaths import UpdateImagePaths
|
24
|
-
|
25
21
|
from coralnet_toolbox.QtProgressBar import ProgressBar
|
26
22
|
|
27
23
|
|
@@ -143,19 +139,18 @@ class OpenProject(QDialog):
|
|
143
139
|
with open(file_path, 'r') as file:
|
144
140
|
project_data = json.load(file)
|
145
141
|
|
142
|
+
# Handle both new and old project formats for images and work areas
|
143
|
+
images_data = project_data.get('images', project_data.get('image_paths'))
|
144
|
+
legacy_workareas = project_data.get('workareas') # For backward compatibility
|
145
|
+
|
146
146
|
# Update main window with loaded project data
|
147
|
-
self.import_images(
|
148
|
-
self.import_workareas(project_data.get('workareas'))
|
147
|
+
self.import_images(images_data, legacy_workareas)
|
149
148
|
self.import_labels(project_data.get('labels'))
|
150
149
|
self.import_annotations(project_data.get('annotations'))
|
151
150
|
|
152
151
|
# Update current project path
|
153
152
|
self.current_project_path = file_path
|
154
153
|
|
155
|
-
QMessageBox.information(self.annotation_window,
|
156
|
-
"Project Loaded",
|
157
|
-
"Project has been successfully loaded.")
|
158
|
-
|
159
154
|
except Exception as e:
|
160
155
|
QMessageBox.warning(self.annotation_window,
|
161
156
|
"Error Loading Project",
|
@@ -168,10 +163,15 @@ class OpenProject(QDialog):
|
|
168
163
|
# Exit
|
169
164
|
self.accept()
|
170
165
|
|
171
|
-
def import_images(self,
|
172
|
-
"""Import images from the given
|
173
|
-
if not
|
166
|
+
def import_images(self, images_data, legacy_workareas=None):
|
167
|
+
"""Import images, states, and work areas from the given data."""
|
168
|
+
if not images_data:
|
174
169
|
return
|
170
|
+
|
171
|
+
# Determine if the format is old (list of strings) or new (list of dicts)
|
172
|
+
is_new_format = isinstance(images_data[0], dict)
|
173
|
+
|
174
|
+
image_paths = [img['path'] for img in images_data] if is_new_format else images_data
|
175
175
|
|
176
176
|
if not all([os.path.exists(path) for path in image_paths]):
|
177
177
|
image_paths, self.updated_paths = UpdateImagePaths.update_paths(image_paths)
|
@@ -183,15 +183,46 @@ class OpenProject(QDialog):
|
|
183
183
|
progress_bar.start_progress(total_images)
|
184
184
|
|
185
185
|
try:
|
186
|
+
# Create a map for quick data lookup if using the new format
|
187
|
+
image_data_map = {img['path']: img for img in images_data} if is_new_format else {}
|
188
|
+
|
186
189
|
# Add images to the image window's raster manager one by one
|
187
190
|
for path in image_paths:
|
188
|
-
# Use the improved add_image method which handles both
|
189
|
-
# adding to raster_manager and updating filtered_paths
|
190
191
|
self.image_window.add_image(path)
|
192
|
+
raster = self.image_window.raster_manager.get_raster(path)
|
193
|
+
if not raster:
|
194
|
+
continue
|
195
|
+
|
196
|
+
# If using the new format, apply saved state and work areas
|
197
|
+
if is_new_format and path in image_data_map:
|
198
|
+
data = image_data_map[path]
|
199
|
+
state = data.get('state', {})
|
200
|
+
work_areas_list = data.get('work_areas', [])
|
201
|
+
|
202
|
+
# Apply raster state
|
203
|
+
raster.checkbox_state = state.get('checkbox_state', False)
|
204
|
+
|
205
|
+
# Import work areas for this image
|
206
|
+
for work_area_data in work_areas_list:
|
207
|
+
try:
|
208
|
+
work_area = WorkArea.from_dict(work_area_data, path)
|
209
|
+
raster.add_work_area(work_area)
|
210
|
+
except Exception as e:
|
211
|
+
print(f"Warning: Could not import work area {work_area_data}: {str(e)}")
|
191
212
|
|
192
213
|
# Update the progress bar
|
193
214
|
progress_bar.update_progress()
|
194
215
|
|
216
|
+
# Handle backward compatibility for old, top-level work areas
|
217
|
+
if legacy_workareas:
|
218
|
+
for image_path, work_areas_list in legacy_workareas.items():
|
219
|
+
current_path = self.updated_paths.get(image_path, image_path)
|
220
|
+
raster = self.image_window.raster_manager.get_raster(current_path)
|
221
|
+
if raster:
|
222
|
+
for work_area_data in work_areas_list:
|
223
|
+
work_area = WorkArea.from_dict(work_area_data, current_path)
|
224
|
+
raster.add_work_area(work_area)
|
225
|
+
|
195
226
|
# Show the last image if any were imported
|
196
227
|
if self.image_window.raster_manager.image_paths:
|
197
228
|
self.image_window.load_image_by_path(self.image_window.raster_manager.image_paths[-1])
|
@@ -204,69 +235,6 @@ class OpenProject(QDialog):
|
|
204
235
|
# Close progress bar
|
205
236
|
progress_bar.stop_progress()
|
206
237
|
progress_bar.close()
|
207
|
-
|
208
|
-
def import_workareas(self, workareas):
|
209
|
-
"""Import work areas for each image."""
|
210
|
-
if not workareas:
|
211
|
-
return
|
212
|
-
|
213
|
-
# Start the progress bar
|
214
|
-
total_images = len(workareas)
|
215
|
-
progress_bar = ProgressBar(self.annotation_window, title="Importing Work Areas")
|
216
|
-
progress_bar.show()
|
217
|
-
progress_bar.start_progress(total_images)
|
218
|
-
|
219
|
-
try:
|
220
|
-
# Loop through each image's work areas
|
221
|
-
for image_path, work_areas_list in workareas.items():
|
222
|
-
|
223
|
-
# Check if the image path was updated (moved)
|
224
|
-
updated_path = False
|
225
|
-
|
226
|
-
if image_path not in self.image_window.raster_manager.image_paths:
|
227
|
-
# Check if the path was updated
|
228
|
-
if image_path in self.updated_paths:
|
229
|
-
image_path = self.updated_paths[image_path]
|
230
|
-
updated_path = True
|
231
|
-
else:
|
232
|
-
print(f"Warning: Image not found for work areas: {image_path}")
|
233
|
-
continue
|
234
|
-
|
235
|
-
# Get the raster for this image
|
236
|
-
raster = self.image_window.raster_manager.get_raster(image_path)
|
237
|
-
if not raster:
|
238
|
-
print(f"Warning: Could not get raster for image: {image_path}")
|
239
|
-
continue
|
240
|
-
|
241
|
-
# Import each work area for this image
|
242
|
-
for work_area_data in work_areas_list:
|
243
|
-
try:
|
244
|
-
# Update image path if it was changed
|
245
|
-
if updated_path:
|
246
|
-
work_area_data['image_path'] = image_path
|
247
|
-
|
248
|
-
# Create WorkArea from dictionary
|
249
|
-
work_area = WorkArea.from_dict(work_area_data, image_path)
|
250
|
-
|
251
|
-
# Add work area to the raster
|
252
|
-
raster.add_work_area(work_area)
|
253
|
-
|
254
|
-
except Exception as e:
|
255
|
-
print(f"Warning: Could not import work area {work_area_data}: {str(e)}")
|
256
|
-
continue
|
257
|
-
|
258
|
-
# Update the progress bar
|
259
|
-
progress_bar.update_progress()
|
260
|
-
|
261
|
-
except Exception as e:
|
262
|
-
QMessageBox.warning(self.annotation_window,
|
263
|
-
"Error Importing Work Areas",
|
264
|
-
f"An error occurred while importing work areas: {str(e)}")
|
265
|
-
|
266
|
-
finally:
|
267
|
-
# Close progress bar
|
268
|
-
progress_bar.stop_progress()
|
269
|
-
progress_bar.close()
|
270
238
|
|
271
239
|
def import_labels(self, labels):
|
272
240
|
"""Import labels from the given list."""
|
@@ -96,10 +96,9 @@ class SaveProject(QDialog):
|
|
96
96
|
|
97
97
|
try:
|
98
98
|
project_data = {
|
99
|
-
'
|
99
|
+
'images': self.get_images(),
|
100
100
|
'labels': self.get_labels(),
|
101
|
-
'annotations': self.get_annotations()
|
102
|
-
'workareas': self.get_workareas()
|
101
|
+
'annotations': self.get_annotations()
|
103
102
|
}
|
104
103
|
|
105
104
|
with open(file_path, 'w') as file:
|
@@ -125,10 +124,10 @@ class SaveProject(QDialog):
|
|
125
124
|
self.accept()
|
126
125
|
|
127
126
|
def get_images(self):
|
128
|
-
"""Get the list of image paths
|
127
|
+
"""Get the list of image objects, including paths, states, and work areas."""
|
129
128
|
# Start the progress bar
|
130
129
|
total_images = len(self.image_window.raster_manager.image_paths)
|
131
|
-
progress_bar = ProgressBar(self.label_window, "Exporting
|
130
|
+
progress_bar = ProgressBar(self.label_window, "Exporting Image Data")
|
132
131
|
progress_bar.show()
|
133
132
|
progress_bar.start_progress(total_images)
|
134
133
|
|
@@ -137,7 +136,19 @@ class SaveProject(QDialog):
|
|
137
136
|
|
138
137
|
# Loop through all of the image paths
|
139
138
|
for image_path in self.image_window.raster_manager.image_paths:
|
140
|
-
|
139
|
+
raster = self.image_window.raster_manager.get_raster(image_path)
|
140
|
+
if raster:
|
141
|
+
# Get work areas for this raster
|
142
|
+
work_areas_list = [wa.to_dict() for wa in raster.get_work_areas()]
|
143
|
+
|
144
|
+
image_data = {
|
145
|
+
'path': image_path,
|
146
|
+
'state': {
|
147
|
+
'checkbox_state': raster.checkbox_state
|
148
|
+
},
|
149
|
+
'work_areas': work_areas_list
|
150
|
+
}
|
151
|
+
export_images.append(image_data)
|
141
152
|
progress_bar.update_progress()
|
142
153
|
|
143
154
|
except Exception as e:
|
@@ -234,42 +245,6 @@ class SaveProject(QDialog):
|
|
234
245
|
|
235
246
|
return export_annotations
|
236
247
|
|
237
|
-
def get_workareas(self):
|
238
|
-
"""Get the work areas to export."""
|
239
|
-
# Start progress bar
|
240
|
-
total_rasters = len(self.image_window.raster_manager.image_paths)
|
241
|
-
progress_bar = ProgressBar(self.annotation_window, title="Exporting Work Areas")
|
242
|
-
progress_bar.show()
|
243
|
-
progress_bar.start_progress(total_rasters)
|
244
|
-
|
245
|
-
try:
|
246
|
-
export_workareas = {}
|
247
|
-
|
248
|
-
# Loop through all rasters to get their work areas
|
249
|
-
for image_path in self.image_window.raster_manager.image_paths:
|
250
|
-
raster = self.image_window.raster_manager.get_raster(image_path)
|
251
|
-
if raster and raster.has_work_areas():
|
252
|
-
work_areas_list = []
|
253
|
-
for work_area in raster.get_work_areas():
|
254
|
-
work_areas_list.append(work_area.to_dict())
|
255
|
-
|
256
|
-
if work_areas_list: # Only add if there are work areas
|
257
|
-
export_workareas[image_path] = work_areas_list
|
258
|
-
|
259
|
-
progress_bar.update_progress()
|
260
|
-
|
261
|
-
except Exception as e:
|
262
|
-
QMessageBox.warning(self.annotation_window,
|
263
|
-
"Error Exporting Work Areas",
|
264
|
-
f"An error occurred while exporting work areas: {str(e)}")
|
265
|
-
|
266
|
-
finally:
|
267
|
-
# Stop the progress bar
|
268
|
-
progress_bar.stop_progress()
|
269
|
-
progress_bar.close()
|
270
|
-
|
271
|
-
return export_workareas
|
272
|
-
|
273
248
|
def get_project_path(self):
|
274
249
|
"""Get the current project path."""
|
275
250
|
return self.current_project_path
|
@@ -290,4 +265,4 @@ class SaveProject(QDialog):
|
|
290
265
|
"""Handle dialog rejection (Cancel or close)"""
|
291
266
|
if self.current_project_path:
|
292
267
|
self.file_path_edit.setText(self.current_project_path)
|
293
|
-
super().reject()
|
268
|
+
super().reject()
|
@@ -123,18 +123,29 @@ class Detect(Base):
|
|
123
123
|
|
124
124
|
def update_sam_task_state(self):
|
125
125
|
"""
|
126
|
-
Centralized method to check if SAM is loaded and update task
|
126
|
+
Centralized method to check if SAM is loaded and update task accordingly.
|
127
|
+
If the user has selected to use SAM, this function ensures the task is set to 'segment'.
|
128
|
+
Crucially, it does NOT alter the task if SAM is not selected, respecting the
|
129
|
+
user's choice from the 'Task' dropdown.
|
127
130
|
"""
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
131
|
+
# Check if the user wants to use the SAM model
|
132
|
+
if self.use_sam_dropdown.currentText() == "True":
|
133
|
+
# SAM is requested. Check if it's actually available.
|
134
|
+
sam_is_available = (
|
135
|
+
hasattr(self, 'sam_dialog') and
|
136
|
+
self.sam_dialog is not None and
|
137
|
+
self.sam_dialog.loaded_model is not None
|
138
|
+
)
|
139
|
+
|
140
|
+
if sam_is_available:
|
141
|
+
# If SAM is wanted and available, the task must be segmentation.
|
142
|
+
self.task = 'segment'
|
143
|
+
else:
|
144
|
+
# If SAM is wanted but not available, revert the dropdown and do nothing else.
|
145
|
+
# The 'is_sam_model_deployed' function already handles showing an error message.
|
146
|
+
self.use_sam_dropdown.setCurrentText("False")
|
147
|
+
|
148
|
+
# If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
|
138
149
|
|
139
150
|
def load_model(self):
|
140
151
|
"""
|
@@ -123,17 +123,29 @@ class Segment(Base):
|
|
123
123
|
|
124
124
|
def update_sam_task_state(self):
|
125
125
|
"""
|
126
|
-
Centralized method to check if SAM is loaded and update task
|
126
|
+
Centralized method to check if SAM is loaded and update task accordingly.
|
127
|
+
If the user has selected to use SAM, this function ensures the task is set to 'segment'.
|
128
|
+
Crucially, it does NOT alter the task if SAM is not selected, respecting the
|
129
|
+
user's choice from the 'Task' dropdown.
|
127
130
|
"""
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
131
|
+
# Check if the user wants to use the SAM model
|
132
|
+
if self.use_sam_dropdown.currentText() == "True":
|
133
|
+
# SAM is requested. Check if it's actually available.
|
134
|
+
sam_is_available = (
|
135
|
+
hasattr(self, 'sam_dialog') and
|
136
|
+
self.sam_dialog is not None and
|
137
|
+
self.sam_dialog.loaded_model is not None
|
138
|
+
)
|
139
|
+
|
140
|
+
if sam_is_available:
|
141
|
+
# If SAM is wanted and available, the task must be segmentation.
|
142
|
+
self.task = 'segment'
|
143
|
+
else:
|
144
|
+
# If SAM is wanted but not available, revert the dropdown and do nothing else.
|
145
|
+
# The 'is_sam_model_deployed' function already handles showing an error message.
|
146
|
+
self.use_sam_dropdown.setCurrentText("False")
|
147
|
+
|
148
|
+
# If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
|
137
149
|
|
138
150
|
def load_model(self):
|
139
151
|
"""
|
@@ -42,7 +42,7 @@ class Base(QDialog):
|
|
42
42
|
self.annotation_window = main_window.annotation_window
|
43
43
|
self.image_window = main_window.image_window
|
44
44
|
|
45
|
-
self.resize(
|
45
|
+
self.resize(800, 800)
|
46
46
|
self.setWindowIcon(get_icon("coral.png"))
|
47
47
|
self.setWindowTitle("Export Dataset")
|
48
48
|
|
@@ -64,10 +64,8 @@ class Base(QDialog):
|
|
64
64
|
self.setup_output_layout()
|
65
65
|
# Setup the ratio layout
|
66
66
|
self.setup_ratio_layout()
|
67
|
-
# Setup the
|
68
|
-
self.
|
69
|
-
# Setup the options layout
|
70
|
-
self.setup_options_layout()
|
67
|
+
# Setup the data selection layout
|
68
|
+
self.setup_data_selection_layout()
|
71
69
|
# Setup the table layout
|
72
70
|
self.setup_table_layout()
|
73
71
|
# Setup the status layout
|
@@ -147,10 +145,25 @@ class Base(QDialog):
|
|
147
145
|
group_box.setLayout(layout)
|
148
146
|
self.layout.addWidget(group_box)
|
149
147
|
|
150
|
-
def
|
151
|
-
"""Setup the
|
148
|
+
def setup_data_selection_layout(self):
|
149
|
+
"""Setup the layout for data selection options in a horizontal arrangement."""
|
150
|
+
options_layout = QHBoxLayout()
|
151
|
+
|
152
|
+
# Create and add the group boxes
|
153
|
+
annotation_types_group = self.create_annotation_layout()
|
154
|
+
image_options_group = self.create_image_source_layout()
|
155
|
+
negative_samples_group = self.create_negative_samples_layout()
|
156
|
+
|
157
|
+
options_layout.addWidget(annotation_types_group)
|
158
|
+
options_layout.addWidget(image_options_group)
|
159
|
+
options_layout.addWidget(negative_samples_group)
|
160
|
+
|
161
|
+
self.layout.addLayout(options_layout)
|
162
|
+
|
163
|
+
def create_annotation_layout(self):
|
164
|
+
"""Creates the annotation type checkboxes layout group box."""
|
152
165
|
group_box = QGroupBox("Annotation Types")
|
153
|
-
layout =
|
166
|
+
layout = QVBoxLayout()
|
154
167
|
|
155
168
|
self.include_patches_checkbox = QCheckBox("Include Patch Annotations")
|
156
169
|
self.include_rectangles_checkbox = QCheckBox("Include Rectangle Annotations")
|
@@ -161,30 +174,24 @@ class Base(QDialog):
|
|
161
174
|
layout.addWidget(self.include_polygons_checkbox)
|
162
175
|
|
163
176
|
group_box.setLayout(layout)
|
164
|
-
|
177
|
+
return group_box
|
165
178
|
|
166
|
-
def
|
167
|
-
"""
|
168
|
-
group_box = QGroupBox("Image
|
169
|
-
layout =
|
179
|
+
def create_image_source_layout(self):
|
180
|
+
"""Creates the image source options layout group box."""
|
181
|
+
group_box = QGroupBox("Image Source")
|
182
|
+
layout = QVBoxLayout()
|
170
183
|
|
171
|
-
# Create a button group for the image checkboxes
|
172
184
|
self.image_options_group = QButtonGroup(self)
|
173
185
|
|
174
186
|
self.all_images_radio = QRadioButton("All Images")
|
175
187
|
self.filtered_images_radio = QRadioButton("Filtered Images")
|
176
188
|
|
177
|
-
# Add the radio buttons to the button group
|
178
189
|
self.image_options_group.addButton(self.all_images_radio)
|
179
190
|
self.image_options_group.addButton(self.filtered_images_radio)
|
180
|
-
|
181
|
-
# Ensure only one radio button can be checked at a time
|
182
191
|
self.image_options_group.setExclusive(True)
|
183
192
|
|
184
|
-
# Set the default radio button
|
185
193
|
self.all_images_radio.setChecked(True)
|
186
194
|
|
187
|
-
# Connect radio button signals
|
188
195
|
self.all_images_radio.toggled.connect(self.update_image_selection)
|
189
196
|
self.filtered_images_radio.toggled.connect(self.update_image_selection)
|
190
197
|
|
@@ -192,7 +199,32 @@ class Base(QDialog):
|
|
192
199
|
layout.addWidget(self.filtered_images_radio)
|
193
200
|
|
194
201
|
group_box.setLayout(layout)
|
195
|
-
|
202
|
+
return group_box
|
203
|
+
|
204
|
+
def create_negative_samples_layout(self):
|
205
|
+
"""Creates the negative sample options layout group box."""
|
206
|
+
group_box = QGroupBox("Negative Samples")
|
207
|
+
layout = QVBoxLayout()
|
208
|
+
|
209
|
+
self.negative_samples_group = QButtonGroup(self)
|
210
|
+
|
211
|
+
self.include_negatives_radio = QRadioButton("Include Negatives")
|
212
|
+
self.exclude_negatives_radio = QRadioButton("Exclude Negatives")
|
213
|
+
|
214
|
+
self.negative_samples_group.addButton(self.include_negatives_radio)
|
215
|
+
self.negative_samples_group.addButton(self.exclude_negatives_radio)
|
216
|
+
self.negative_samples_group.setExclusive(True)
|
217
|
+
|
218
|
+
self.exclude_negatives_radio.setChecked(True)
|
219
|
+
|
220
|
+
# Connect to update stats when changed. Only one needed for the group.
|
221
|
+
self.include_negatives_radio.toggled.connect(self.update_summary_statistics)
|
222
|
+
|
223
|
+
layout.addWidget(self.include_negatives_radio)
|
224
|
+
layout.addWidget(self.exclude_negatives_radio)
|
225
|
+
|
226
|
+
group_box.setLayout(layout)
|
227
|
+
return group_box
|
196
228
|
|
197
229
|
def setup_table_layout(self):
|
198
230
|
"""Setup the label counts table layout."""
|
@@ -424,6 +456,11 @@ class Base(QDialog):
|
|
424
456
|
else:
|
425
457
|
images = self.image_window.raster_manager.image_paths
|
426
458
|
|
459
|
+
# If "Exclude Negatives" is checked, only use images that have selected annotations.
|
460
|
+
if self.exclude_negatives_radio.isChecked():
|
461
|
+
image_paths_with_annotations = {a.image_path for a in self.selected_annotations}
|
462
|
+
images = [img for img in images if img in image_paths_with_annotations]
|
463
|
+
|
427
464
|
random.shuffle(images)
|
428
465
|
|
429
466
|
train_split = int(len(images) * self.train_ratio)
|
@@ -551,9 +588,6 @@ class Base(QDialog):
|
|
551
588
|
|
552
589
|
self.updating_summary_statistics = True
|
553
590
|
|
554
|
-
# Split the data by images
|
555
|
-
self.split_data()
|
556
|
-
|
557
591
|
# Selected labels based on user's selection
|
558
592
|
self.selected_labels = []
|
559
593
|
for row in range(self.label_counts_table.rowCount()):
|
@@ -564,6 +598,9 @@ class Base(QDialog):
|
|
564
598
|
|
565
599
|
# Filter annotations based on the selected annotation types and current tab
|
566
600
|
self.selected_annotations = self.filter_annotations()
|
601
|
+
|
602
|
+
# Split the data by images
|
603
|
+
self.split_data()
|
567
604
|
|
568
605
|
# Split the data by annotations
|
569
606
|
self.determine_splits()
|
@@ -704,4 +741,4 @@ class Base(QDialog):
|
|
704
741
|
raise NotImplementedError("Method must be implemented in the subclass.")
|
705
742
|
|
706
743
|
def process_annotations(self, annotations, split_dir, split):
|
707
|
-
raise NotImplementedError("Method must be implemented in the subclass.")
|
744
|
+
raise NotImplementedError("Method must be implemented in the subclass.")
|
@@ -60,6 +60,10 @@ class Classify(Base):
|
|
60
60
|
self.include_polygons_checkbox.setChecked(True)
|
61
61
|
self.include_polygons_checkbox.setEnabled(True)
|
62
62
|
|
63
|
+
# Disable negative sample options for classification
|
64
|
+
self.include_negatives_radio.setEnabled(False)
|
65
|
+
self.exclude_negatives_radio.setEnabled(False)
|
66
|
+
|
63
67
|
def create_dataset(self, output_dir_path):
|
64
68
|
"""
|
65
69
|
Create an image classification dataset.
|
@@ -219,4 +223,4 @@ class Classify(Base):
|
|
219
223
|
progress_bar.stop_progress()
|
220
224
|
progress_bar.close()
|
221
225
|
progress_bar = None
|
222
|
-
gc.collect()
|
226
|
+
gc.collect()
|