coralnet-toolbox 0.0.75__py2.py3-none-any.whl → 0.0.76__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37):
  1. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +57 -12
  2. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +44 -14
  3. coralnet_toolbox/Explorer/transformer_models.py +13 -2
  4. coralnet_toolbox/IO/QtExportMaskAnnotations.py +538 -403
  5. coralnet_toolbox/Icons/system_monitor.png +0 -0
  6. coralnet_toolbox/QtEventFilter.py +4 -4
  7. coralnet_toolbox/QtMainWindow.py +104 -64
  8. coralnet_toolbox/QtProgressBar.py +1 -0
  9. coralnet_toolbox/QtSystemMonitor.py +370 -0
  10. coralnet_toolbox/Results/ConvertResults.py +14 -8
  11. coralnet_toolbox/Results/ResultsProcessor.py +3 -2
  12. coralnet_toolbox/SAM/QtDeployGenerator.py +1 -1
  13. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  14. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +15 -10
  15. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +10 -6
  16. coralnet_toolbox/Tile/QtTileBatchInference.py +4 -4
  17. coralnet_toolbox/Tools/QtSAMTool.py +140 -91
  18. coralnet_toolbox/Transformers/Models/GroundingDINO.py +72 -0
  19. coralnet_toolbox/Transformers/Models/OWLViT.py +72 -0
  20. coralnet_toolbox/Transformers/Models/OmDetTurbo.py +68 -0
  21. coralnet_toolbox/Transformers/Models/QtBase.py +120 -0
  22. coralnet_toolbox/{AutoDistill → Transformers}/Models/__init__.py +1 -1
  23. coralnet_toolbox/{AutoDistill → Transformers}/QtBatchInference.py +15 -15
  24. coralnet_toolbox/{AutoDistill → Transformers}/QtDeployModel.py +18 -16
  25. coralnet_toolbox/{AutoDistill → Transformers}/__init__.py +1 -1
  26. coralnet_toolbox/__init__.py +1 -1
  27. coralnet_toolbox/utilities.py +0 -15
  28. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.76.dist-info}/METADATA +9 -9
  29. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.76.dist-info}/RECORD +33 -31
  30. coralnet_toolbox/AutoDistill/Models/GroundingDINO.py +0 -81
  31. coralnet_toolbox/AutoDistill/Models/OWLViT.py +0 -76
  32. coralnet_toolbox/AutoDistill/Models/OmDetTurbo.py +0 -75
  33. coralnet_toolbox/AutoDistill/Models/QtBase.py +0 -112
  34. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.76.dist-info}/WHEEL +0 -0
  35. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.76.dist-info}/entry_points.txt +0 -0
  36. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.76.dist-info}/licenses/LICENSE.txt +0 -0
  37. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.76.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,370 @@
1
+ import psutil
2
+ import GPUtil
3
+ import collections
4
+ import time
5
+ import platform
6
+
7
+ import pyqtgraph as pg
8
+ from PyQt5.QtWidgets import QMainWindow, QGridLayout, QWidget, QLabel
9
+ from PyQt5.QtCore import Qt, QTimer
10
+ from PyQt5.QtGui import QFont
11
+
12
+ from coralnet_toolbox.Icons import get_icon
13
+
14
+
15
+ # ----------------------------------------------------------------------------------------------------------------------
16
+ # Classes
17
+ # ----------------------------------------------------------------------------------------------------------------------
18
+
19
+
20
class SystemMonitor(QMainWindow):
    """
    A PyQt5 window that plots real-time CPU, per-core CPU, memory, GPU, disk I/O
    and network usage.

    Sampling is driven by a 1-second QTimer that is started when the window is
    shown and stopped when it is hidden or closed, so no sampling happens while
    the monitor is not visible.
    """

    def __init__(self):
        super().__init__()

        # --- Window Properties ---
        self.setWindowTitle("System Monitor")
        self.setWindowIcon(get_icon("system_monitor.png"))
        self.setGeometry(100, 100, 800, 900)

        # --- Data Storage ---
        # Fixed-size deques: appending a new sample silently drops the oldest one.
        self.history_size = 100  # Number of data points to display
        self.cpu_data = self._make_history()
        self.mem_data = self._make_history()
        self.gpu_data = self._make_history()

        # Per-core CPU data.
        # psutil.cpu_count() returns None when the count cannot be determined;
        # fall back to the number of per-CPU readings psutil actually reports.
        self.cpu_cores = psutil.cpu_count(logical=True) or len(psutil.cpu_percent(percpu=True))
        self.per_core_data = [self._make_history() for _ in range(self.cpu_cores)]

        # Disk I/O and network traffic data (MB/s)
        self.disk_read_data = self._make_history()
        self.disk_write_data = self._make_history()
        self.net_sent_data = self._make_history()
        self.net_recv_data = self._make_history()

        # Baselines used to turn cumulative counters into rates
        # (prev_disk_io, prev_net_io, prev_time); refreshed again on every show.
        self._reset_rate_baselines()

        # X-axis data (time steps)
        self.x_axis = list(range(self.history_size))

        # --- Initialize UI ---
        self.setup_layout()

        # --- Timer for Real-Time Updates ---
        # Fires update_plots every 1000 ms. It is NOT started here; showEvent starts it.
        # Parented to the window so its lifetime is tied to the window's.
        self.timer = QTimer(self)
        self.timer.setInterval(1000)
        self.timer.timeout.connect(self.update_plots)

    def _make_history(self):
        """Return a zero-filled deque that keeps only the last `history_size` samples."""
        return collections.deque([0] * self.history_size, maxlen=self.history_size)

    def _reset_rate_baselines(self):
        """
        Re-sample the cumulative counters so the next update measures only the
        interval since now (not, e.g., the whole time the window was hidden).
        """
        self.prev_disk_io = psutil.disk_io_counters()
        self.prev_net_io = psutil.net_io_counters()
        self.prev_time = time.monotonic()
        # With interval=None, psutil.cpu_percent reports usage since its previous
        # call; calling it here discards whatever period preceded this reset.
        psutil.cpu_percent()
        psutil.cpu_percent(percpu=True)

    @staticmethod
    def _transfer_rate_mb(curr_bytes, prev_bytes, seconds):
        """
        Return the rate in MB/s (1024**2 bytes) between two cumulative byte counters
        sampled `seconds` apart. `seconds` must be positive.

        A counter that went backwards yields 0 instead of a negative rate.
        """
        return max(curr_bytes - prev_bytes, 0) / seconds / (1024**2)

    def setup_layout(self):
        """
        Build the window contents: a hardware-info header row followed by a
        3x2 grid of plots (CPU, per-core CPU, memory, disk I/O, GPU, network).
        """
        # --- Central Widget and Layout ---
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QGridLayout(central_widget)
        layout.setVerticalSpacing(20)  # Add more vertical spacing between plots

        # --- Styling ---
        # NOTE(review): pg.setConfigOption changes pyqtgraph's process-wide defaults, so any
        # pyqtgraph plot created after this point (including in other dialogs) inherits
        # these colors. Confirm this is intended.
        pg.setConfigOption('background', 'w')  # White background for graphs
        pg.setConfigOption('foreground', '#000000')  # Black foreground elements in graphs
        plot_label_style = {"color": "#000000", "font-size": "14pt"}  # Black text for plot labels
        title_style = {"color": "#000000", "font-size": "16pt", "font-weight": "bold"}  # Black text

        # --- Stats Labels - Only Hardware Info ---
        self.stats_widget = QWidget()
        stats_layout = QGridLayout(self.stats_widget)
        stats_layout.setHorizontalSpacing(20)  # Add horizontal spacing between items
        stats_layout.setContentsMargins(10, 10, 10, 20)  # Add margins around the widget

        # Get system hardware information
        cpu_info = platform.processor() or "CPU"
        total_memory = round(psutil.virtual_memory().total / (1024**3), 1)

        try:
            gpus = GPUtil.getGPUs()
            gpu_info = gpus[0].name if gpus else "No GPU detected"
        except Exception:
            gpu_info = "GPU information unavailable"

        info_label_style = "color: black; font-size: 11pt; font-weight: bold; padding: 5px;"

        def create_info_label(text, column):
            """Add a centered, styled hardware-info label in `column` of the header row."""
            label = QLabel(text)
            label.setStyleSheet(info_label_style)
            label.setAlignment(Qt.AlignCenter)
            stats_layout.addWidget(label, 0, column)
            # Equal stretch on every column distributes the space evenly
            stats_layout.setColumnStretch(column, 1)
            return label

        self.cpu_info_label = create_info_label(f"CPU: {cpu_info}", 0)
        self.mem_info_label = create_info_label(f"Memory: {total_memory:.1f} GB", 1)
        self.gpu_info_label = create_info_label(f"GPU: {gpu_info}", 2)

        layout.addWidget(self.stats_widget, 0, 0, 1, 2)

        def create_plot(title, y_label, y_units):
            """Create a PlotWidget with the shared title/axis styling and a light grid."""
            plot_widget = pg.PlotWidget()
            plot_widget.setTitle(title, **title_style)
            plot_widget.setLabel("left", y_label, units=y_units, **plot_label_style)
            plot_widget.setLabel("bottom", "Time (s)", **plot_label_style)
            plot_widget.showGrid(x=True, y=True, alpha=0.3)
            return plot_widget

        def configure_plot(plot_widget, fixed_y_range=None):
            """Reserve space for the x-axis and set either a fixed or auto-ranged y-axis."""
            plot_widget.getPlotItem().getAxis('bottom').setHeight(40)
            plot_widget.getPlotItem().layout.setContentsMargins(10, 10, 10, 30)

            if fixed_y_range is None:
                plot_widget.enableAutoRange(axis='y')
            else:
                min_val, max_val = fixed_y_range
                plot_widget.setYRange(min_val, max_val, padding=0.05)

            return plot_widget

        def create_value_label(plot_widget, initial_text="Current: N/A"):
            """Create a bold text overlay pinned to the top-left corner of the plot area."""
            text_item = pg.TextItem(text=initial_text, color='#000000', anchor=(0, 0))
            font = QFont()
            font.setBold(True)
            text_item.setFont(font)
            # Parent the label to the ViewBox instead of adding it to the plot's data
            # space. In the ViewBox's own widget coordinates (10, 10) is 10 px from its
            # top-left corner, whatever the axis ranges are. In data space, (10, 10) lies
            # outside the view whenever an auto-ranged plot shows values below 10.
            text_item.setParentItem(plot_widget.getPlotItem().getViewBox())
            text_item.setPos(10, 10)
            return text_item

        # --- CPU Plot --- (fixed 0-100 range)
        self.cpu_plot_widget = create_plot("CPU Usage (%)", "Usage", "%")
        self.cpu_curve = self.cpu_plot_widget.plot(self.x_axis, list(self.cpu_data), pen=pg.mkPen('#00A3FF', width=3))
        self.cpu_value_label = create_value_label(self.cpu_plot_widget, "Current: 0%")
        configure_plot(self.cpu_plot_widget, fixed_y_range=(0, 100))
        layout.addWidget(self.cpu_plot_widget, 1, 0)

        # --- Per-Core CPU Plot --- (fixed 0-100 range, one hue per core)
        self.core_plot_widget = create_plot("Per-Core CPU Usage (%)", "Usage", "%")
        self.per_core_curves = [
            self.core_plot_widget.plot(
                self.x_axis,
                list(core_history),
                pen=pg.mkPen(pg.intColor(i, hues=self.cpu_cores), width=2),
            )
            for i, core_history in enumerate(self.per_core_data)
        ]
        configure_plot(self.core_plot_widget, fixed_y_range=(0, 100))
        layout.addWidget(self.core_plot_widget, 1, 1)

        # --- Memory Plot --- (fixed 0-100 range)
        self.mem_plot_widget = create_plot("Memory Usage (%)", "Usage", "%")
        self.mem_curve = self.mem_plot_widget.plot(self.x_axis, list(self.mem_data), pen=pg.mkPen('#FF5733', width=3))
        self.mem_value_label = create_value_label(self.mem_plot_widget, "Current: 0%")
        configure_plot(self.mem_plot_widget, fixed_y_range=(0, 100))
        layout.addWidget(self.mem_plot_widget, 2, 0)

        # --- Disk I/O Plot --- (auto-ranging)
        self.disk_plot_widget = create_plot("Disk I/O (MB/s)", "Transfer Rate", "MB/s")
        self.disk_read_curve = self.disk_plot_widget.plot(
            self.x_axis,
            list(self.disk_read_data),
            pen=pg.mkPen('#FFDD33', width=3)
        )
        self.disk_write_curve = self.disk_plot_widget.plot(
            self.x_axis,
            list(self.disk_write_data),
            pen=pg.mkPen('#33DDFF', width=3)
        )
        self.disk_value_label = create_value_label(self.disk_plot_widget, "R: 0 MB/s, W: 0 MB/s")
        configure_plot(self.disk_plot_widget, fixed_y_range=None)  # Auto-range
        layout.addWidget(self.disk_plot_widget, 2, 1)

        # --- GPU Plot --- (fixed 0-100 range)
        self.gpu_plot_widget = create_plot("GPU Usage (%)", "Usage", "%")
        self.gpu_curve = self.gpu_plot_widget.plot(self.x_axis, list(self.gpu_data), pen=pg.mkPen('#33FF57', width=3))
        self.gpu_value_label = create_value_label(self.gpu_plot_widget, "Current: 0%")
        configure_plot(self.gpu_plot_widget, fixed_y_range=(0, 100))
        layout.addWidget(self.gpu_plot_widget, 3, 0)

        # --- Network Traffic Plot --- (auto-ranging)
        self.net_plot_widget = create_plot("Network Traffic (MB/s)", "Transfer Rate", "MB/s")
        self.net_sent_curve = self.net_plot_widget.plot(
            self.x_axis,
            list(self.net_sent_data),
            pen=pg.mkPen('#FF33DD', width=3)
        )
        self.net_recv_curve = self.net_plot_widget.plot(
            self.x_axis,
            list(self.net_recv_data),
            pen=pg.mkPen('#DD33FF', width=3)
        )
        self.net_value_label = create_value_label(self.net_plot_widget, "Up: 0 MB/s, Down: 0 MB/s")
        configure_plot(self.net_plot_widget, fixed_y_range=None)  # Auto-range
        layout.addWidget(self.net_plot_widget, 3, 1)

    def update_plots(self):
        """
        Sample current system statistics, append them to the history buffers and
        redraw every curve. Connected to the QTimer's timeout signal.

        Disk and network rates are averaged over the measured time since the
        previous sample (time.monotonic), not the nominal timer interval.
        """
        curr_time = time.monotonic()
        time_delta = curr_time - self.prev_time
        # Rates need a positive interval; skip them rather than divide by zero.
        have_interval = time_delta > 0

        # --- CPU usage ---
        cpu_percent = psutil.cpu_percent()
        self.cpu_data.append(cpu_percent)
        self.cpu_value_label.setText(f"Current: {cpu_percent:.1f}%")

        # Per-core CPU usage (readings beyond the cores we built curves for are ignored)
        per_core = psutil.cpu_percent(percpu=True)
        for i, usage in enumerate(per_core[:len(self.per_core_data)]):
            self.per_core_data[i].append(usage)
            self.per_core_curves[i].setData(self.x_axis, list(self.per_core_data[i]))

        # --- Memory usage ---
        mem_percent = psutil.virtual_memory().percent
        self.mem_data.append(mem_percent)
        self.mem_value_label.setText(f"Current: {mem_percent:.1f}%")

        # --- GPU usage (first GPU only; 0 when no GPU is found or the query fails) ---
        try:
            gpus = GPUtil.getGPUs()
            gpu_percent = gpus[0].load * 100 if gpus else None
        except Exception:
            gpu_percent = None

        if gpu_percent is None:
            gpu_percent = 0
            self.gpu_value_label.setText("Current: 0%")
        else:
            self.gpu_value_label.setText(f"Current: {gpu_percent:.1f}%")
        self.gpu_data.append(gpu_percent)

        # --- Disk I/O (read/write rates in MB/s) ---
        curr_disk_io = psutil.disk_io_counters()
        if curr_disk_io and self.prev_disk_io and have_interval:
            read_mb_per_sec = self._transfer_rate_mb(
                curr_disk_io.read_bytes, self.prev_disk_io.read_bytes, time_delta
            )
            write_mb_per_sec = self._transfer_rate_mb(
                curr_disk_io.write_bytes, self.prev_disk_io.write_bytes, time_delta
            )

            self.disk_read_data.append(read_mb_per_sec)
            self.disk_write_data.append(write_mb_per_sec)
            self.disk_value_label.setText(f"R: {read_mb_per_sec:.1f} MB/s, W: {write_mb_per_sec:.1f} MB/s")

            # Re-enable y auto-ranging (e.g. after the user panned/zoomed)
            self.disk_plot_widget.enableAutoRange(axis='y')

        self.prev_disk_io = curr_disk_io

        # --- Network traffic (sent/received rates in MB/s) ---
        curr_net_io = psutil.net_io_counters()
        if curr_net_io and self.prev_net_io and have_interval:
            sent_mb_per_sec = self._transfer_rate_mb(
                curr_net_io.bytes_sent, self.prev_net_io.bytes_sent, time_delta
            )
            recv_mb_per_sec = self._transfer_rate_mb(
                curr_net_io.bytes_recv, self.prev_net_io.bytes_recv, time_delta
            )

            self.net_sent_data.append(sent_mb_per_sec)
            self.net_recv_data.append(recv_mb_per_sec)
            self.net_value_label.setText(f"Up: {sent_mb_per_sec:.1f} MB/s, Down: {recv_mb_per_sec:.1f} MB/s")

            # Re-enable y auto-ranging (e.g. after the user panned/zoomed)
            self.net_plot_widget.enableAutoRange(axis='y')

        self.prev_net_io = curr_net_io

        # --- Update Plot Curves (deques are converted to lists for plotting) ---
        self.cpu_curve.setData(self.x_axis, list(self.cpu_data))
        self.mem_curve.setData(self.x_axis, list(self.mem_data))
        self.gpu_curve.setData(self.x_axis, list(self.gpu_data))
        self.disk_read_curve.setData(self.x_axis, list(self.disk_read_data))
        self.disk_write_curve.setData(self.x_axis, list(self.disk_write_data))
        self.net_sent_curve.setData(self.x_axis, list(self.net_sent_data))
        self.net_recv_curve.setData(self.x_axis, list(self.net_recv_data))

        self.prev_time = curr_time

    # --- Event Handlers to Control Monitoring ---

    def showEvent(self, event):
        """
        Overrides the QWidget's showEvent.
        Starts the timer when the window is shown. The rate baselines are
        refreshed first so the first sample covers only the last interval.
        """
        super().showEvent(event)
        if not self.timer.isActive():
            self._reset_rate_baselines()
            self.timer.start()

    def hideEvent(self, event):
        """
        Overrides the QWidget's hideEvent.
        Stops the timer when the window is hidden (e.g., minimized).
        """
        super().hideEvent(event)
        if self.timer.isActive():
            self.timer.stop()

    def closeEvent(self, event):
        """
        Overrides the QMainWindow's closeEvent.
        Ensures the timer is stopped when the window is closed.
        """
        if self.timer.isActive():
            self.timer.stop()
        super().closeEvent(event)
@@ -94,7 +94,7 @@ class ConvertResults:
94
94
  names: Dictionary mapping class ids to class names (optional)
95
95
 
96
96
  Returns:
97
- generator: Yields Ultralytics Results objects
97
+ list: List of Ultralytics Results objects
98
98
  """
99
99
  # Convert single inputs to lists
100
100
  if not isinstance(detections, list):
@@ -108,8 +108,15 @@ class ConvertResults:
108
108
  if not image_paths:
109
109
  image_paths = [None] * len(images)
110
110
 
111
+ results_list = []
111
112
  for detection, image, path in zip(detections, images, image_paths):
112
113
  # Ensure image is numpy array
114
+ if isinstance(image, str):
115
+ # If image is a string (path), we can't proceed because we need the actual image data
116
+ raise ValueError(
117
+ f"Expected image array, got path string: {image}. Please load the image before passing."
118
+ )
119
+
113
120
  if torch.is_tensor(image):
114
121
  image = image.cpu().numpy()
115
122
 
@@ -118,7 +125,8 @@ class ConvertResults:
118
125
  names = {i: str(i) for i in range(len(detection))} if len(detection) > 0 else {}
119
126
 
120
127
  if len(detection) == 0:
121
- return [Results(orig_img=image, path=path, names=names)]
128
+ results_list.append(Results(orig_img=image, path=path, names=names))
129
+ continue
122
130
 
123
131
  # Handle masks if present
124
132
  if hasattr(detection, 'mask') and detection.mask is not None:
@@ -142,11 +150,9 @@ class ConvertResults:
142
150
  scaled_boxes = scaled_boxes.unsqueeze(0)
143
151
  scaled_boxes = torch.cat([scaled_boxes, scores, cls], dim=1)
144
152
 
145
- # Create and return Results object
146
- return [Results(image,
147
- path=path,
148
- names=names,
149
- boxes=scaled_boxes,
150
- masks=scaled_masks)]
153
+ # Create Results object
154
+ results_list.append(Results(image, path=path, names=names, boxes=scaled_boxes, masks=scaled_masks))
155
+
156
+ return results_list
151
157
 
152
158
 
@@ -1,6 +1,7 @@
1
1
  from PyQt5.QtCore import QPointF
2
2
 
3
3
  from torchvision.ops import nms
4
+ from ultralytics.utils.nms import TorchNMS
4
5
 
5
6
  from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
6
7
  from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
@@ -62,7 +63,7 @@ class ResultsProcessor:
62
63
  def filter_by_iou(self, results):
63
64
  """Filter the results based on the IoU threshold."""
64
65
  try:
65
- results = results[nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)]
66
+ results = results[TorchNMS.fast_nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh)]
66
67
  except Exception as e:
67
68
  print(f"Warning: Failed to filter results by IoU\n{e}")
68
69
 
@@ -106,7 +107,7 @@ class ResultsProcessor:
106
107
  Get the indices of results that pass the IoU threshold.
107
108
  """
108
109
  try:
109
- indices = nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh).tolist()
110
+ indices = TorchNMS.fast_nms(results.boxes.xyxy, results.boxes.conf, self.iou_thresh).tolist()
110
111
  except Exception as e:
111
112
  print(f"Warning: Failed to get indices for IoU\n{e}")
112
113
  indices = []
@@ -36,7 +36,7 @@ class DeployGeneratorDialog(QDialog):
36
36
 
37
37
  def __init__(self, main_window, parent=None):
38
38
  """
39
- Initialize the AutoDistillDeployModelDialog.
39
+ Initialize the DeployGeneratorDialog.
40
40
 
41
41
  Args:
42
42
  main_window: The main application window.
@@ -143,6 +143,12 @@ class DeployPredictorDialog(QDialog):
143
143
  group_box = QGroupBox("Parameters")
144
144
  layout = QFormLayout()
145
145
 
146
+ # Output type dropdown (polygon or rectangle)
147
+ self.output_type_dropdown = QComboBox()
148
+ self.output_type_dropdown.addItems(["Polygon", "Rectangle"])
149
+ self.output_type_dropdown.setCurrentIndex(0) # Default to Polygon
150
+ layout.addRow("Output Type:", self.output_type_dropdown)
151
+
146
152
  # Allow holes dropdown
147
153
  self.allow_holes_dropdown = QComboBox()
148
154
  self.allow_holes_dropdown.addItems(["True", "False"])
@@ -246,6 +252,10 @@ class DeployPredictorDialog(QDialog):
246
252
  """Return the current setting for allowing holes."""
247
253
  return self.allow_holes_dropdown.currentText() == "True"
248
254
 
255
+ def get_output_type(self):
256
+ """Return the current setting for output type."""
257
+ return self.output_type_dropdown.currentText()
258
+
249
259
  def initialize_uncertainty_threshold(self):
250
260
  """Initialize the uncertainty threshold slider with the current value"""
251
261
  current_value = self.main_window.get_uncertainty_thresh()
@@ -10,10 +10,10 @@ import torch
10
10
  from torch.cuda import empty_cache
11
11
 
12
12
  import pyqtgraph as pg
13
- from pyqtgraph.Qt import QtGui
14
13
 
15
14
  from ultralytics import YOLOE
16
15
  from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
16
+ from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
17
17
 
18
18
  from PyQt5.QtCore import Qt
19
19
  from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
@@ -90,7 +90,7 @@ class DeployGeneratorDialog(QDialog):
90
90
  self.imported_vpes = [] # VPEs loaded from file
91
91
  self.reference_vpes = [] # VPEs created from reference images
92
92
 
93
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93
+ self.device = None # Will be set in showEvent
94
94
 
95
95
  # Main vertical layout for the dialog
96
96
  self.layout = QVBoxLayout(self)
@@ -195,6 +195,8 @@ class DeployGeneratorDialog(QDialog):
195
195
  self.initialize_iou_threshold()
196
196
  self.initialize_area_threshold()
197
197
 
198
+ # Update the device
199
+ self.device = self.main_window.device
198
200
  # Configure the image window's UI elements for this specific dialog
199
201
  self.configure_image_window_for_dialog()
200
202
  # Sync with main window's images BEFORE updating labels
@@ -767,20 +769,20 @@ class DeployGeneratorDialog(QDialog):
767
769
  try:
768
770
  # Load the VPE file
769
771
  loaded_data = torch.load(file_path)
770
-
771
- # TODO Move tensors to the appropriate device
772
- # device = self.main_window.device
772
+
773
+ # Move tensors to the appropriate device
774
+ device = self.main_window.device
773
775
 
774
776
  # Check format type and handle appropriately
775
777
  if isinstance(loaded_data, list):
776
778
  # New format: list of VPE tensors
777
- self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
779
+ self.imported_vpes = [vpe.to(device) for vpe in loaded_data]
778
780
  vpe_count = len(self.imported_vpes)
779
781
  self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")
780
782
 
781
783
  elif isinstance(loaded_data, torch.Tensor):
782
784
  # Legacy format: single tensor - convert to list for consistency
783
- loaded_vpe = loaded_data.to(self.device)
785
+ loaded_vpe = loaded_data.to(device)
784
786
  # Store as a single-item list
785
787
  self.imported_vpes = [loaded_vpe]
786
788
  self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")
@@ -948,7 +950,7 @@ class DeployGeneratorDialog(QDialog):
948
950
  self.model_path = self.model_combo.currentText()
949
951
 
950
952
  # Load model using registry
951
- self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device) # TODO
953
+ self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device)
952
954
 
953
955
  # Create a dummy visual dictionary for standard model loading
954
956
  visual_prompts = dict(
@@ -966,7 +968,7 @@ class DeployGeneratorDialog(QDialog):
966
968
  self.loaded_model.predict(
967
969
  np.zeros((640, 640, 3), dtype=np.uint8),
968
970
  visual_prompts=visual_prompts.copy(), # This needs to happen to properly initialize the predictor
969
- predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
971
+ predictor=YOLOEVPDetectPredictor if self.task == "detect" else YOLOEVPSegPredictor,
970
972
  imgsz=640,
971
973
  conf=0.99,
972
974
  )
@@ -1108,6 +1110,9 @@ class DeployGeneratorDialog(QDialog):
1108
1110
  # We will predict on the first work area/full image.
1109
1111
  input_image = inputs[0]
1110
1112
 
1113
+ # Set the predictor
1114
+ predictor = YOLOEVPDetectPredictor if self.task == "detect" else YOLOEVPSegPredictor
1115
+
1111
1116
  # Iterate through each reference image and its annotations
1112
1117
  for ref_path, ref_annotations in reference_dict.items():
1113
1118
  # The 'refer_image' parameter is the path to the current reference image
@@ -1123,7 +1128,7 @@ class DeployGeneratorDialog(QDialog):
1123
1128
  results = self.loaded_model.predict(input_image,
1124
1129
  refer_image=ref_path,
1125
1130
  visual_prompts=visual_prompts,
1126
- predictor=YOLOEVPSegPredictor, # TODO This is necessary here?
1131
+ predictor=predictor,
1127
1132
  imgsz=self.imgsz_spinbox.value(),
1128
1133
  conf=self.main_window.get_uncertainty_thresh(),
1129
1134
  iou=self.main_window.get_iou_thresh(),
@@ -1,6 +1,5 @@
1
1
  import warnings
2
2
 
3
- import os
4
3
  import gc
5
4
 
6
5
  import numpy as np
@@ -14,7 +13,6 @@ from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
14
13
  from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
15
14
 
16
15
  from PyQt5.QtCore import Qt
17
- from PyQt5.QtGui import QColor
18
16
  from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
19
17
  QHBoxLayout, QLabel, QMessageBox, QPushButton,
20
18
  QSlider, QSpinBox, QVBoxLayout, QGroupBox,
@@ -408,7 +406,7 @@ class DeployPredictorDialog(QDialog):
408
406
  self.loaded_model.predict(
409
407
  np.zeros((640, 640, 3), dtype=np.uint8),
410
408
  visual_prompts=visuals.copy(), # This needs to happen to properly initialize the predictor
411
- predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
409
+ predictor=YOLOEVPDetectPredictor if self.task == 'detect' else YOLOEVPSegPredictor,
412
410
  imgsz=640,
413
411
  conf=0.99,
414
412
  )
@@ -545,12 +543,15 @@ class DeployPredictorDialog(QDialog):
545
543
 
546
544
  # Get the scaled visual prompts
547
545
  visual_prompts = self.scale_prompts(bboxes, masks)
546
+
547
+ # Set the predictor
548
+ predictor=YOLOEVPDetectPredictor if self.task == 'detect' else YOLOEVPSegPredictor
548
549
 
549
550
  try:
550
551
  # Make predictions
551
552
  results = self.loaded_model.predict(self.resized_image,
552
- visual_prompts=visual_prompts.copy(),
553
- predictor=YOLOEVPSegPredictor,
553
+ visual_prompts=visual_prompts.copy(),
554
+ predictor=predictor,
554
555
  imgsz=max(self.resized_image.shape[:2]),
555
556
  conf=self.main_window.get_uncertainty_thresh(),
556
557
  iou=self.main_window.get_iou_thresh(),
@@ -615,6 +616,9 @@ class DeployPredictorDialog(QDialog):
615
616
  progress_bar = ProgressBar(self.annotation_window, title="Making Predictions")
616
617
  progress_bar.show()
617
618
  progress_bar.start_progress(len(target_images))
619
+
620
+ # Set the predictor
621
+ predictor = YOLOEVPDetectPredictor if self.task == 'detect' else YOLOEVPSegPredictor
618
622
 
619
623
  for target_image in target_images:
620
624
 
@@ -623,7 +627,7 @@ class DeployPredictorDialog(QDialog):
623
627
  results = self.loaded_model.predict(target_image,
624
628
  refer_image=refer_image,
625
629
  visual_prompts=visual_prompts.copy(),
626
- predictor=YOLOEVPSegPredictor,
630
+ predictor=predictor,
627
631
  imgsz=self.imgsz_spinbox.value(),
628
632
  conf=self.main_window.get_uncertainty_thresh(),
629
633
  iou=self.main_window.get_iou_thresh(),
@@ -37,8 +37,8 @@ class TileBatchInference(QDialog):
37
37
  self.detect_dialog = main_window.detect_deploy_model_dialog
38
38
  self.segment_dialog = main_window.segment_deploy_model_dialog
39
39
  self.sam_dialog = main_window.sam_deploy_generator_dialog
40
- self.autodistill_dialog = main_window.auto_distill_deploy_model_dialog
41
-
40
+ self.transformers_dialog = main_window.transformers_deploy_model_dialog
41
+
42
42
  # Create a dictionary of the different model dialogs and their loaded models
43
43
  self.model_dialogs = {}
44
44
 
@@ -167,8 +167,8 @@ class TileBatchInference(QDialog):
167
167
  self.model_dialogs["Segment"] = self.segment_dialog
168
168
  if self.sam_dialog and getattr(self.sam_dialog, "loaded_model", None):
169
169
  self.model_dialogs["SAM Generator"] = self.sam_dialog
170
- if self.autodistill_dialog and getattr(self.autodistill_dialog, "loaded_model", None):
171
- self.model_dialogs["Autodistill"] = self.autodistill_dialog
170
+ if self.transformers_dialog and getattr(self.transformers_dialog, "loaded_model", None):
171
+ self.model_dialogs["Transformers"] = self.transformers_dialog
172
172
 
173
173
  # Update the model combo box with the available models
174
174
  self.update_model_combo()