lazylabel-gui 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1173 @@
1
+ """Multi view mode handler."""
2
+
3
+ from PyQt6.QtCore import QPointF, Qt
4
+ from PyQt6.QtGui import QBrush, QColor, QPen, QPolygonF
5
+ from PyQt6.QtWidgets import QGraphicsEllipseItem, QGraphicsRectItem
6
+
7
+ from lazylabel.utils.logger import logger
8
+
9
+ from ...utils import mask_to_pixmap
10
+ from ..hoverable_pixelmap_item import HoverablePixmapItem
11
+ from ..hoverable_polygon_item import HoverablePolygonItem
12
+ from .base_mode import BaseModeHandler
13
+
14
+
15
+ class MultiViewModeHandler(BaseModeHandler):
16
+ """Handler for multi view mode operations."""
17
+
18
def __init__(self, main_window):
    """Create the handler and ensure per-viewer segment tracking exists.

    Args:
        main_window: The application main window this handler operates on.
    """
    super().__init__(main_window)

    # Another component may already have created the tracking dict.
    if hasattr(main_window, "multi_view_segment_items"):
        return

    # One empty {segment_index: [graphics items]} mapping per viewer.
    viewer_count = self._get_num_viewers()
    main_window.multi_view_segment_items = {
        idx: {} for idx in range(viewer_count)
    }
25
+
26
def _get_num_viewers(self):
    """Return how many viewers multi-view mode currently has.

    Uses the length of ``main_window.multi_view_viewers`` when that
    attribute exists; otherwise reads ``"num_viewers"`` from the main
    window's multi-view configuration.
    """
    window = self.main_window
    if not hasattr(window, "multi_view_viewers"):
        # Viewers not created yet: fall back to settings.
        return window._get_multi_view_config()["num_viewers"]
    return len(window.multi_view_viewers)
34
+
35
def _get_other_viewer_indices(self, current_viewer_index):
    """Return the indices of every viewer except ``current_viewer_index``."""
    others = []
    for idx in range(self._get_num_viewers()):
        if idx != current_viewer_index:
            others.append(idx)
    return others
39
+
40
def handle_ai_click(self, pos, event, viewer_index=0):
    """Handle an AI-mode mouse press in one multi-view viewer.

    The method first makes sure the AI model for ``viewer_index`` is usable
    (initialised, not currently updating, and holding the current image);
    if not, it kicks off the relevant loading step, notifies the user and
    returns. A left-button press only records the press position, so the
    move/release handlers can tell a click from a drag. Any other button
    adds a negative point, runs a prediction with all accumulated points
    and shows a preview in this viewer and the other viewers.

    Args:
        pos: Click position in viewer scene coordinates (``x()``/``y()``).
        event: The mouse event; only ``event.button()`` is read.
        viewer_index: Index of the viewer that received the click.
    """
    window = self.main_window

    # First use: no models exist yet, so start background initialisation.
    if (
        not hasattr(window, "multi_view_models")
        or not window.multi_view_models
        or all(m is None for m in window.multi_view_models)
    ):
        worker = getattr(window, "multi_view_init_worker", None)
        if worker and worker.isRunning():
            return  # Already initialising; don't show duplicate messages

        window._show_notification(
            "Initializing AI models for multi-view mode...", duration=0
        )  # Persistent message
        window._initialize_multi_view_models()
        return  # Models will load in the background

    if viewer_index >= len(window.multi_view_models):
        return

    # This viewer's model is still loading - the user has to wait.
    if window.multi_view_models_updating[viewer_index]:
        window._show_notification(
            f"AI model for viewer {viewer_index + 1} is still loading...",
            duration=2000,
        )
        return

    # The model exists but does not have the current image loaded yet.
    if window.multi_view_models_dirty[viewer_index]:
        # Avoid starting a second loading session on top of a running one.
        any_loading = any(
            window.multi_view_models_updating[i]
            for i in range(len(window.multi_view_models_updating))
        )
        if any_loading:
            window._show_notification(
                "AI models are already loading, please wait...", duration=2000
            )
            return

        window._show_notification(
            f"Loading image into AI model for viewer {viewer_index + 1}...",
            duration=0,
        )
        # Progress is only tracked for models that are dirty and have an image.
        dirty_count = sum(
            1
            for i in range(len(window.multi_view_models))
            if window.multi_view_models_dirty[i] and window.multi_view_images[i]
        )
        if dirty_count > 0:
            window._multi_view_loading_step = 0
            window._multi_view_total_steps = dirty_count
        window._start_sequential_multi_view_sam_loading()
        return  # Let the image load in the background

    model = window.multi_view_models[viewer_index]
    if model is None:
        logger.error(f"AI model not initialized for viewer {viewer_index + 1}")
        window._show_warning_notification(
            f"AI model not initialized for viewer {viewer_index + 1}"
        )
        return

    if event.button() == Qt.MouseButton.LeftButton:
        # Left press: remember where it started; the move/release handlers
        # decide whether it turns out to be a click or a drag.
        if not hasattr(window, "multi_view_ai_click_starts"):
            window.multi_view_ai_click_starts = [None] * self._get_num_viewers()
        if not hasattr(window, "multi_view_ai_rects"):
            window.multi_view_ai_rects = [None] * self._get_num_viewers()

        window.multi_view_ai_click_starts[viewer_index] = pos
        return

    # From here on the click is always a negative point: the positive
    # (left-button) case returned above.
    viewer = window.multi_view_viewers[viewer_index]

    # Red marker centred on the click position.
    point_diameter = window.point_radius * 2
    point_item = QGraphicsEllipseItem(
        pos.x() - window.point_radius,
        pos.y() - window.point_radius,
        point_diameter,
        point_diameter,
    )
    point_item.setBrush(QBrush(QColor(255, 0, 0)))
    point_item.setPen(QPen(Qt.PenStyle.NoPen))
    viewer.scene().addItem(point_item)

    # Track point items so they can be cleared later.
    if not hasattr(window, "multi_view_point_items"):
        window.multi_view_point_items = {
            i: [] for i in range(self._get_num_viewers())
        }
    window.multi_view_point_items[viewer_index].append(point_item)

    # Record the action for undo; a new action invalidates the redo stack.
    window.action_history.append(
        {
            "type": "add_point",
            "point_type": "negative",
            "point_coords": [int(pos.x()), int(pos.y())],
            "point_item": point_item,
            "viewer_mode": "multi",
            "viewer_index": viewer_index,
        }
    )
    window.redo_history.clear()

    try:
        # Convert the click position to model coordinates.
        model_pos = window._transform_multi_view_coords_to_sam_coords(
            pos, viewer_index
        )

        # Per-viewer accumulated prompt points (both lists are created so
        # other code can rely on the attributes existing).
        if not hasattr(window, "multi_view_positive_points"):
            window.multi_view_positive_points = {
                i: [] for i in range(self._get_num_viewers())
            }
        if not hasattr(window, "multi_view_negative_points"):
            window.multi_view_negative_points = {
                i: [] for i in range(self._get_num_viewers())
            }

        window.multi_view_negative_points[viewer_index].append(model_pos)

        # Predict with ALL accumulated points for this viewer.
        positive_points = window.multi_view_positive_points[viewer_index]
        negative_points = window.multi_view_negative_points[viewer_index]
        result = model.predict(positive_points, negative_points)

        if result is not None and len(result) == 3:
            mask, _scores, _logits = result

            # The model may return a non-boolean mask; threshold it.
            if mask.dtype != bool:
                mask = mask > 0.5

            if not hasattr(window, "multi_view_ai_predictions"):
                window.multi_view_ai_predictions = {}

            # Store every accumulated point: positives (label 1) first,
            # then negatives (label 0).
            all_points = list(positive_points) + list(negative_points)
            all_labels = [1] * len(positive_points) + [0] * len(negative_points)

            window.multi_view_ai_predictions[viewer_index] = {
                "mask": mask.astype(bool),
                "points": all_points,
                "labels": all_labels,
                "model_pos": model_pos,
                "positive": False,
            }

            self._display_ai_preview(mask, viewer_index)

            # Preview the same model coordinates in every other viewer.
            for other_viewer_index in self._get_other_viewer_indices(viewer_index):
                self._generate_paired_ai_preview(
                    viewer_index, other_viewer_index, pos, model_pos, False
                )

    except Exception as e:
        logger.error(f"Error processing AI click for viewer {viewer_index}: {e}")
255
+
256
def handle_polygon_click(self, pos, viewer_index=0):
    """Add a polygon vertex for ``viewer_index``, or close the polygon.

    When the polygon already has more than two vertices and the click lands
    within ``polygon_join_threshold`` of the first vertex, the polygon is
    finalised instead of receiving a new vertex.

    Args:
        pos: Click position in scene coordinates.
        viewer_index: Index of the viewer that received the click.
    """
    window = self.main_window
    vertices = window.multi_view_polygon_points[viewer_index]

    # Close the polygon when clicking near its first vertex.
    if vertices and len(vertices) > 2:
        anchor = vertices[0]
        dx = pos.x() - anchor.x()
        dy = pos.y() - anchor.y()
        # Compare squared distances to avoid a square root.
        if dx**2 + dy**2 < window.polygon_join_threshold**2:
            self._finalize_multi_view_polygon(viewer_index)
            return

    vertices.append(pos)

    # Draw a marker for the new vertex.
    viewer = window.multi_view_viewers[viewer_index]
    diameter = window.point_radius * 2
    marker = QGraphicsEllipseItem(
        pos.x() - window.point_radius,
        pos.y() - window.point_radius,
        diameter,
        diameter,
    )
    marker.setBrush(QBrush(QColor(0, 255, 255)))  # Cyan like single view
    marker.setPen(QPen(Qt.PenStyle.NoPen))
    viewer.scene().addItem(marker)

    # Remember the marker so it can be removed when the polygon is cleared.
    window.multi_view_polygon_preview_items[viewer_index].append(marker)
290
+
291
def handle_bbox_start(self, pos, viewer_index=0):
    """Begin a bounding-box drag in ``viewer_index``.

    Stores the start position and adds an (initially empty) yellow
    rectangle item to that viewer's scene.

    Args:
        pos: Drag start position in scene coordinates.
        viewer_index: Index of the viewer where the drag started.
    """
    window = self.main_window

    # Lazily create the per-viewer state lists.
    for attr_name in ("multi_view_bbox_starts", "multi_view_bbox_rects"):
        if not hasattr(window, attr_name):
            setattr(window, attr_name, [None] * self._get_num_viewers())

    window.multi_view_bbox_starts[viewer_index] = pos

    # Rubber-band rectangle for this viewer.
    outline = QGraphicsRectItem()
    outline.setPen(QPen(QColor(255, 255, 0), 2))  # Yellow
    window.multi_view_viewers[viewer_index].scene().addItem(outline)
    window.multi_view_bbox_rects[viewer_index] = outline
308
+
309
def handle_bbox_drag(self, pos, viewer_index=0):
    """Resize the in-progress bounding box of ``viewer_index`` to ``pos``.

    Does nothing unless a drag was started for that viewer.

    Args:
        pos: Current drag position in scene coordinates.
        viewer_index: Index of the viewer being dragged in.
    """
    window = self.main_window
    if not hasattr(window, "multi_view_bbox_starts"):
        return
    if not hasattr(window, "multi_view_bbox_rects"):
        return

    anchor = window.multi_view_bbox_starts[viewer_index]
    if anchor is None:
        return
    rubber_band = window.multi_view_bbox_rects[viewer_index]
    if rubber_band is None:
        return

    from PyQt6.QtCore import QRectF

    # normalized() keeps width/height positive whichever way the user drags.
    rubber_band.setRect(QRectF(anchor, pos).normalized())
322
+
323
def handle_bbox_complete(self, pos, viewer_index=0):
    """Finish a bounding-box drag and store it as a polygon segment.

    The temporary rectangle is removed from the viewer's scene. If the box
    is at least 2x2, a ``"Polygon"`` segment holding the four corners is
    added for every viewer (same coordinates in each), the action is
    recorded for undo and the redo history is cleared.

    Args:
        pos: Drag end position in scene coordinates.
        viewer_index: Index of the viewer where the drag took place.
    """
    window = self.main_window
    if not hasattr(window, "multi_view_bbox_starts") or not hasattr(
        window, "multi_view_bbox_rects"
    ):
        return

    start_pos = window.multi_view_bbox_starts[viewer_index]
    if start_pos is None:
        return
    rect_item = window.multi_view_bbox_rects[viewer_index]
    if rect_item is None:
        return

    # Final rectangle, independent of drag direction.
    x = min(start_pos.x(), pos.x())
    y = min(start_pos.y(), pos.y())
    width = abs(pos.x() - start_pos.x())
    height = abs(pos.y() - start_pos.y())

    # Remove the temporary rectangle.
    window.multi_view_viewers[viewer_index].scene().removeItem(rect_item)

    # The drag is over whatever happens next; reset the state now so a
    # failure while storing the segment cannot leave a stale start/rect.
    window.multi_view_bbox_starts[viewer_index] = None
    window.multi_view_bbox_rects[viewer_index] = None

    # Only create a segment if the minimum size is met (2x2 pixels).
    if width < 2 or height < 2:
        return

    corners = (
        (x, y),
        (x + width, y),
        (x + width, y + height),
        (x, y + height),
    )

    # One view entry per viewer. Each view gets its OWN nested vertex lists
    # so editing a vertex in one view cannot silently change the others.
    paired_segment = {"type": "Polygon", "views": {}}
    for viewer_idx in range(self._get_num_viewers()):
        paired_segment["views"][viewer_idx] = {
            "vertices": [[cx, cy] for cx, cy in corners],
            "mask": None,
        }

    window.segment_manager.add_segment(paired_segment)

    # Record for undo; a new action invalidates the redo stack.
    window.action_history.append({"type": "add_segment", "data": paired_segment})
    window.redo_history.clear()
393
+
394
def display_all_segments(self):
    """Redraw every segment in every multi-view viewer.

    Existing tracked segment graphics are removed first, the tracking dict
    is rebuilt, and each segment is then displayed: segments with a
    ``"views"`` mapping only in the viewers listed there, all other
    segments in every viewer.
    """
    window = self.main_window

    # Remove whatever is currently drawn.
    if hasattr(window, "multi_view_segment_items"):
        for viewer_idx, per_viewer in window.multi_view_segment_items.items():
            for graphics_items in per_viewer.values():
                # Iterate over a copy to avoid modification during iteration.
                for item in list(graphics_items):
                    try:
                        if item.scene():
                            window.multi_view_viewers[
                                viewer_idx
                            ].scene().removeItem(item)
                    except RuntimeError:
                        # Object has been deleted, skip it
                        continue

    # Fresh tracking structure: {viewer_index: {segment_index: [items]}}.
    viewer_count = self._get_num_viewers()
    window.multi_view_segment_items = {idx: {} for idx in range(viewer_count)}

    for seg_index, segment in enumerate(self.segment_manager.segments):
        color = window._get_color_for_class(segment.get("class_id"))
        has_views = "views" in segment

        for viewer_idx in range(len(window.multi_view_viewers)):
            # Multi-view format: skip viewers without data for this segment.
            # Segments without "views" are shown in all viewers.
            if has_views and viewer_idx not in segment["views"]:
                continue
            self._display_segment_in_viewer(seg_index, segment, viewer_idx, color)
436
+
437
def clear_all_points(self):
    """Discard all temporary drawing state in multi-view mode.

    Clears the in-progress polygon of every viewer, then the AI previews
    and prompt points.
    """
    window = self.main_window
    if hasattr(window, "multi_view_polygon_points"):
        for viewer_idx, _ in enumerate(window.multi_view_polygon_points):
            self._clear_multi_view_polygon(viewer_idx)

    self._clear_ai_previews()
446
+
447
def _add_multi_view_segment(self, segment_type, class_id, viewer_index, view_data):
    """Forward a view-specific segment to the main window.

    All arguments are passed through unchanged to
    ``main_window._add_multi_view_segment``, which owns the actual
    insertion.
    """
    add_segment = self.main_window._add_multi_view_segment
    add_segment(segment_type, class_id, viewer_index, view_data)
453
+
454
def _create_paired_ai_segment(
    self, viewer_index, view_data, other_viewer_index, pos, positive
):
    """Create one AI segment that holds data for two viewers.

    If the other viewer's model is ready, a single-point prediction is run
    on it at ``pos`` and a segment with both views is added to the segment
    manager. Otherwise - or if anything fails before the paired segment is
    stored - a single-view segment for ``viewer_index`` is added instead.

    Args:
        viewer_index: Viewer that produced ``view_data``.
        view_data: Already computed view data for ``viewer_index``.
        other_viewer_index: Viewer to run the matching prediction on.
        pos: Click position in scene coordinates.
        positive: True for a positive prompt point, False for a negative one.
    """
    window = self.main_window
    # Tracks whether the paired segment already reached the segment manager,
    # so a later failure does not add a duplicate fallback segment.
    paired_added = False

    try:
        other_ready = (
            other_viewer_index < len(window.multi_view_models)
            and window.multi_view_models[other_viewer_index] is not None
            and not window.multi_view_models_dirty[other_viewer_index]
            and not window.multi_view_models_updating[other_viewer_index]
        )
        if other_ready:
            other_model = window.multi_view_models[other_viewer_index]

            # Convert the position to model coordinates for the other viewer.
            other_model_pos = window._transform_multi_view_coords_to_sam_coords(
                pos, other_viewer_index
            )

            if positive:
                positive_points, negative_points = [other_model_pos], []
            else:
                positive_points, negative_points = [], [other_model_pos]

            other_result = other_model.predict(positive_points, negative_points)

            if other_result is not None and len(other_result) == 3:
                other_mask, _other_scores, _other_logits = other_result

                # Ensure mask is boolean
                if other_mask.dtype != bool:
                    other_mask = other_mask > 0.5

                other_view_data = {
                    "mask": other_mask.astype(bool),
                    "points": [(pos.x(), pos.y())],
                    "labels": [1 if positive else 0],
                }

                paired_segment = {
                    "type": "AI",
                    "views": {
                        viewer_index: view_data,
                        other_viewer_index: other_view_data,
                    },
                }

                window.segment_manager.add_segment(paired_segment)
                paired_added = True

                # Record for undo; a new action invalidates the redo stack,
                # as in the polygon and bounding-box paths.
                window.action_history.append(
                    {"type": "add_segment", "data": paired_segment}
                )
                window.redo_history.clear()

                # Update UI lists to show the new segment
                window._update_all_lists()
                return

    except Exception as e:
        logger.error(f"Error creating paired AI segment: {e}")
        if paired_added:
            # The paired segment is already stored; adding the fallback
            # as well would create a duplicate.
            return

    # No paired segment could be created: fall back to a single segment.
    self._add_multi_view_segment("AI", None, viewer_index, view_data)
529
+
530
def _display_ai_preview(self, mask, viewer_index):
    """Show ``mask`` as a yellow preview overlay in one viewer.

    Any preview previously shown for that viewer is removed first.

    Args:
        mask: Prediction mask to render.
        viewer_index: Index of the viewer to draw into; out-of-range
            indices are ignored.
    """
    window = self.main_window
    if viewer_index >= len(window.multi_view_viewers):
        return

    viewer = window.multi_view_viewers[viewer_index]

    if not hasattr(window, "multi_view_preview_items"):
        window.multi_view_preview_items = {}
    previews = window.multi_view_preview_items

    # Drop the stale overlay if it is still attached to a scene.
    if viewer_index in previews:
        if previews[viewer_index].scene():
            viewer.scene().removeItem(previews[viewer_index])

    overlay_pixmap = mask_to_pixmap(mask, (255, 255, 0))  # Yellow preview
    overlay = viewer.scene().addPixmap(overlay_pixmap)
    overlay.setZValue(50)
    previews[viewer_index] = overlay
553
+
554
def _generate_paired_ai_preview(
    self, source_viewer_index, target_viewer_index, pos, model_pos, positive
):
    """Preview a single-point AI prediction in another viewer.

    Runs the target viewer's model on ``model_pos`` (used as-is, without
    re-transforming), stores the result under
    ``multi_view_ai_predictions[target_viewer_index]`` and displays it.
    Nothing happens if the target model is missing, dirty or updating.
    Errors are logged, not raised.

    Args:
        source_viewer_index: Viewer the click came from (not used here).
        target_viewer_index: Viewer to generate the preview for.
        pos: Click position in scene coordinates; stored with the prediction.
        model_pos: Click position in model coordinates.
        positive: True for a positive prompt point, False for a negative one.
    """
    window = self.main_window
    try:
        # Bail out unless the target viewer's model is ready.
        if target_viewer_index >= len(window.multi_view_models):
            return
        if window.multi_view_models[target_viewer_index] is None:
            return
        if window.multi_view_models_dirty[target_viewer_index]:
            return
        if window.multi_view_models_updating[target_viewer_index]:
            return

        target_model = window.multi_view_models[target_viewer_index]

        # Use the same model coordinates (no transformation needed)
        target_model_pos = model_pos

        if positive:
            result = target_model.predict([target_model_pos], [])
        else:
            result = target_model.predict([], [target_model_pos])

        if result is None or len(result) != 3:
            return
        mask, scores, logits = result

        # Ensure mask is boolean
        if mask.dtype != bool:
            mask = mask > 0.5

        if not hasattr(window, "multi_view_ai_predictions"):
            window.multi_view_ai_predictions = {}

        window.multi_view_ai_predictions[target_viewer_index] = {
            "mask": mask.astype(bool),
            "points": [(pos.x(), pos.y())],
            "labels": [1 if positive else 0],
            "model_pos": target_model_pos,
            "positive": positive,
        }

        self._display_ai_preview(mask, target_viewer_index)

    except Exception as e:
        logger.error(f"Error generating paired AI preview: {e}")
607
+
608
def _generate_paired_ai_bbox_preview(
    self, source_viewer_index, target_viewer_index, box
):
    """Preview a box-prompted AI prediction in another viewer.

    Runs ``predict_from_box(box)`` on the target viewer's model, stores the
    result under ``multi_view_ai_predictions[target_viewer_index]`` and
    displays it. Nothing happens if the target model is missing, dirty or
    updating. Errors are logged, not raised.

    Args:
        source_viewer_index: Viewer the box came from (not used here).
        target_viewer_index: Viewer to generate the preview for.
        box: Bounding box passed unchanged to the model.
    """
    window = self.main_window
    try:
        # Bail out unless the target viewer's model is ready.
        if target_viewer_index >= len(window.multi_view_models):
            return
        if window.multi_view_models[target_viewer_index] is None:
            return
        if window.multi_view_models_dirty[target_viewer_index]:
            return
        if window.multi_view_models_updating[target_viewer_index]:
            return

        target_model = window.multi_view_models[target_viewer_index]

        # Use the same bounding box coordinates
        result = target_model.predict_from_box(box)
        if result is None or len(result) != 3:
            return
        mask, scores, logits = result

        # Ensure mask is boolean
        if mask.dtype != bool:
            mask = mask > 0.5

        if not hasattr(window, "multi_view_ai_predictions"):
            window.multi_view_ai_predictions = {}

        window.multi_view_ai_predictions[target_viewer_index] = {
            "mask": mask.astype(bool),
            "box": box,
            "points": [],  # Empty for box predictions
            "labels": [],  # Empty for box predictions
        }

        self._display_ai_preview(mask, target_viewer_index)

    except Exception as e:
        logger.error(f"Error generating paired AI bbox preview: {e}")
649
+
650
def _clear_ai_previews(self):
    """Remove AI preview overlays, prompt markers and prompt points.

    Each piece of state is only touched if the corresponding attribute
    exists on the main window.
    """
    window = self.main_window

    # Preview mask overlays.
    if hasattr(window, "multi_view_preview_items"):
        previews = window.multi_view_preview_items
        for viewer_index, preview_item in previews.items():
            if preview_item and preview_item.scene():
                window.multi_view_viewers[viewer_index].scene().removeItem(
                    preview_item
                )
        previews.clear()

    # Stored prediction data.
    if hasattr(window, "multi_view_ai_predictions"):
        window.multi_view_ai_predictions.clear()

    # Prompt point markers drawn in the scenes.
    if hasattr(window, "multi_view_point_items"):
        for viewer_index, markers in window.multi_view_point_items.items():
            for marker in markers:
                if marker.scene():
                    window.multi_view_viewers[viewer_index].scene().removeItem(
                        marker
                    )
            markers.clear()

    # Accumulated prompt coordinates (like single view mode).
    for attr_name in ("multi_view_positive_points", "multi_view_negative_points"):
        if hasattr(window, attr_name):
            for viewer_points in getattr(window, attr_name).values():
                viewer_points.clear()
688
+
689
def save_ai_predictions(self):
    """Turn the current AI preview predictions into stored segments.

    With predictions for two or more viewers, one paired ``"AI"`` segment
    is created whose ``"views"`` holds the data of every viewer that has a
    prediction; its class id is the active class, or one past the highest
    existing class id (1 if there are none). With a single prediction, a
    segment with just that view is created and no class id is set here.
    Each added segment is recorded for undo, the redo history is cleared,
    the UI lists are refreshed and the previews are removed.
    """
    window = self.main_window
    if not hasattr(window, "multi_view_ai_predictions"):
        return

    predictions = window.multi_view_ai_predictions
    if len(predictions) == 0:
        return

    def build_view_data(prediction):
        """Copy the fields of a stored prediction that belong in a view."""
        view_data = {"mask": prediction["mask"]}
        # Points/labels exist for point-based predictions.
        if "points" in prediction:
            view_data["points"] = prediction["points"]
            view_data["labels"] = prediction["labels"]
        # A box exists for box-based predictions.
        if "box" in prediction:
            view_data["box"] = prediction["box"]
        return view_data

    def commit(segment):
        """Store a segment and record it as a new undoable action."""
        window.segment_manager.add_segment(segment)
        window.action_history.append({"type": "add_segment", "data": segment})
        # A new action invalidates the redo stack, as in the polygon and
        # bounding-box paths.
        window.redo_history.clear()

    num_viewers = self._get_num_viewers()
    if len(predictions) >= 2:
        # Use the active class, or determine the next free class id.
        active_class = window.segment_manager.get_active_class()
        if active_class is None:
            existing_classes = window.segment_manager.get_unique_class_ids()
            next_class_id = max(existing_classes) + 1 if existing_classes else 1
        else:
            next_class_id = active_class

        paired_segment = {"type": "AI", "class_id": next_class_id, "views": {}}
        for viewer_index in range(num_viewers):
            if viewer_index in predictions:
                paired_segment["views"][viewer_index] = build_view_data(
                    predictions[viewer_index]
                )

        commit(paired_segment)
    else:
        # Only one viewer has a prediction.
        for viewer_index, prediction in predictions.items():
            commit(
                {"type": "AI", "views": {viewer_index: build_view_data(prediction)}}
            )

    window._update_all_lists()

    # Clear previews after saving
    self._clear_ai_previews()
764
+
765
def _finalize_multi_view_polygon(self, viewer_index):
    """Store the in-progress polygon of ``viewer_index`` as a segment.

    The polygon is mirrored with identical coordinates to every other
    viewer, added to the segment manager, recorded for undo (clearing the
    redo history), and the viewer's temporary polygon state is removed.
    Polygons with fewer than three points are ignored.

    Args:
        viewer_index: Index of the viewer whose polygon is being closed.
    """
    window = self.main_window
    points = window.multi_view_polygon_points[viewer_index]
    if len(points) < 3:
        return

    vertices = [[p.x(), p.y()] for p in points]
    num_viewers = self._get_num_viewers()

    # The drawing viewer's entry is inserted first, then the mirrors.
    paired_segment = {"type": "Polygon", "views": {}}
    paired_segment["views"][viewer_index] = {"vertices": vertices, "mask": None}

    for other_viewer_index in range(num_viewers):
        if other_viewer_index != viewer_index:
            paired_segment["views"][other_viewer_index] = {
                # Copy the inner [x, y] lists too: a shallow copy would let
                # an in-place vertex edit in one view leak into the others.
                "vertices": [list(vertex) for vertex in vertices],
                "mask": None,
            }

    window.segment_manager.add_segment(paired_segment)

    # Record for undo; a new action invalidates the redo stack.
    window.action_history.append({"type": "add_segment", "data": paired_segment})
    window.redo_history.clear()

    # Update UI
    window._update_all_lists()
    viewer_count_text = "all viewers" if num_viewers > 2 else "both viewers"
    window._show_notification(
        f"Polygon created and mirrored to {viewer_count_text}."
    )

    # Clear polygon state for this viewer
    self._clear_multi_view_polygon(viewer_index)
817
+
818
def _clear_multi_view_polygon(self, viewer_index):
    """Drop the in-progress polygon of one viewer.

    Clears the stored vertex list and removes the preview graphics from the
    viewer's scene. Missing attributes and out-of-range indices are
    tolerated.

    Args:
        viewer_index: Index of the viewer whose polygon state is cleared.
    """
    window = self.main_window

    # Vertex list.
    has_points = hasattr(window, "multi_view_polygon_points")
    if has_points and viewer_index < len(window.multi_view_polygon_points):
        window.multi_view_polygon_points[viewer_index].clear()

    # Preview graphics; every precondition must hold before touching them.
    if not hasattr(window, "multi_view_viewers"):
        return
    if not viewer_index < len(window.multi_view_viewers):
        return
    if not hasattr(window, "multi_view_polygon_preview_items"):
        return
    if not viewer_index < len(window.multi_view_polygon_preview_items):
        return

    viewer = window.multi_view_viewers[viewer_index]
    preview_items = window.multi_view_polygon_preview_items[viewer_index]
    for graphics_item in preview_items:
        if graphics_item.scene():
            viewer.scene().removeItem(graphics_item)
    window.multi_view_polygon_preview_items[viewer_index].clear()
838
+
839
+ def _display_segment_in_viewer(
840
+ self, segment_index, segment, viewer_index, base_color
841
+ ):
842
+ """Display a specific segment in a specific viewer."""
843
+ if viewer_index >= len(self.main_window.multi_view_viewers):
844
+ return
845
+
846
+ viewer = self.main_window.multi_view_viewers[viewer_index]
847
+
848
+ # Initialize segment items for this viewer if needed
849
+ if segment_index not in self.main_window.multi_view_segment_items[viewer_index]:
850
+ self.main_window.multi_view_segment_items[viewer_index][segment_index] = []
851
+
852
+ # Get segment data (either from views or direct)
853
+ if "views" in segment and viewer_index in segment["views"]:
854
+ segment_data = segment["views"][viewer_index]
855
+ segment_type = segment["type"]
856
+ else:
857
+ segment_data = segment
858
+ segment_type = segment["type"]
859
+
860
+ # Display based on type
861
+ if segment_type == "Polygon" and segment_data.get("vertices"):
862
+ # Display polygon
863
+ qpoints = [QPointF(p[0], p[1]) for p in segment_data["vertices"]]
864
+ poly_item = HoverablePolygonItem(QPolygonF(qpoints))
865
+
866
+ default_brush = QBrush(
867
+ QColor(base_color.red(), base_color.green(), base_color.blue(), 70)
868
+ )
869
+ hover_brush = QBrush(
870
+ QColor(base_color.red(), base_color.green(), base_color.blue(), 170)
871
+ )
872
+ poly_item.set_brushes(default_brush, hover_brush)
873
+ poly_item.set_segment_info(segment_index, self.main_window)
874
+ poly_item.setPen(QPen(Qt.GlobalColor.transparent))
875
+
876
+ logger.debug(
877
+ f"Created HoverablePolygonItem for segment {segment_index} in viewer {viewer_index}"
878
+ )
879
+
880
+ viewer.scene().addItem(poly_item)
881
+ self.main_window.multi_view_segment_items[viewer_index][
882
+ segment_index
883
+ ].append(poly_item)
884
+
885
+ elif segment_type == "AI" and segment_data.get("mask") is not None:
886
+ # Display AI mask
887
+ default_pixmap = mask_to_pixmap(
888
+ segment_data["mask"], base_color.getRgb()[:3], alpha=70
889
+ )
890
+ hover_pixmap = mask_to_pixmap(
891
+ segment_data["mask"], base_color.getRgb()[:3], alpha=170
892
+ )
893
+ pixmap_item = HoverablePixmapItem()
894
+ pixmap_item.set_pixmaps(default_pixmap, hover_pixmap)
895
+ pixmap_item.set_segment_info(segment_index, self.main_window)
896
+
897
+ logger.debug(
898
+ f"Created HoverablePixmapItem for segment {segment_index} in viewer {viewer_index}"
899
+ )
900
+
901
+ viewer.scene().addItem(pixmap_item)
902
+ pixmap_item.setZValue(segment_index + 1)
903
+ self.main_window.multi_view_segment_items[viewer_index][
904
+ segment_index
905
+ ].append(pixmap_item)
906
+
907
def handle_ai_drag(self, pos, viewer_index=0):
    """Update the AI rubber-band rectangle while the pointer moves.

    Nothing happens unless the main window has both
    ``multi_view_ai_click_starts`` and ``multi_view_ai_rects`` and a start
    position is stored for ``viewer_index``. Movements of 5 units or less
    from that start position are ignored. On a larger movement the viewer's
    rectangle item is created if needed (cyan, dashed, width 2), added to the
    viewer's scene and stored in ``multi_view_ai_rects``; its geometry is
    then set to the normalized rectangle spanned by the start position and
    ``pos``.

    Args:
        pos: Current pointer position; must provide ``x()`` and ``y()``.
        viewer_index: Index of the viewer being dragged in.
    """
    window = self.main_window
    tracking = hasattr(window, "multi_view_ai_click_starts") and hasattr(
        window, "multi_view_ai_rects"
    )
    if not tracking or window.multi_view_ai_click_starts[viewer_index] is None:
        return

    anchor = window.multi_view_ai_click_starts[viewer_index]
    offset_x = pos.x() - anchor.x()
    offset_y = pos.y() - anchor.y()
    moved = (offset_x**2 + offset_y**2) ** 0.5

    # Below the minimum drag distance the gesture still counts as a click.
    if not (moved > 5):
        return

    viewer = window.multi_view_viewers[viewer_index]

    band = window.multi_view_ai_rects[viewer_index]
    if band is None:
        # First qualifying move: create the rubber band and remember it.
        band = QGraphicsRectItem()
        band.setPen(QPen(Qt.GlobalColor.cyan, 2, Qt.PenStyle.DashLine))
        viewer.scene().addItem(band)
        window.multi_view_ai_rects[viewer_index] = band

    # QRectF is not among the module-level imports, so it is imported here.
    from PyQt6.QtCore import QRectF

    band.setRect(QRectF(anchor, pos).normalized())
942
+
943
def handle_ai_complete(self, pos, viewer_index=0):
    """Finish an AI gesture in one viewer as a box prompt or a point prompt.

    Does nothing when the main window has no ``multi_view_ai_click_starts``
    or no start position is stored for ``viewer_index``.

    If a rubber-band item exists for the viewer and ``pos`` is more than
    5 units from the start position, the gesture is a drag: the item is
    removed from the viewer's scene, both per-viewer slots are reset to
    ``None`` and, when the rectangle is wider and taller than 10 units,
    ``_handle_multi_view_ai_bounding_box`` is called with it. A smaller
    rectangle is discarded.

    Otherwise the gesture is a click: the start slot is reset, any
    rubber-band item is removed and its slot reset, and
    ``_handle_multi_view_ai_click_point`` is called with ``positive=True``.

    Args:
        pos: Release position; must provide ``x()`` and ``y()``.
        viewer_index: Index of the viewer the gesture happened in.
    """
    window = self.main_window
    if not hasattr(window, "multi_view_ai_click_starts"):
        return

    anchor = window.multi_view_ai_click_starts[viewer_index]
    if anchor is None:
        return

    travelled = (
        (pos.x() - anchor.x()) ** 2 + (pos.y() - anchor.y()) ** 2
    ) ** 0.5

    # Rubber band created by the drag handler, if there is one.
    band = None
    if hasattr(window, "multi_view_ai_rects"):
        band = window.multi_view_ai_rects[viewer_index]

    if band is not None and travelled > 5:
        # Drag: read the geometry before the item is torn down.
        area = band.rect()
        window.multi_view_viewers[viewer_index].scene().removeItem(band)
        window.multi_view_ai_rects[viewer_index] = None
        window.multi_view_ai_click_starts[viewer_index] = None

        if area.width() > 10 and area.height() > 10:  # Minimum box size
            self._handle_multi_view_ai_bounding_box(area, viewer_index)
        return

    # Click: reset the gesture, drop a leftover rubber band, add the point.
    window.multi_view_ai_click_starts[viewer_index] = None
    if band is not None:
        window.multi_view_viewers[viewer_index].scene().removeItem(band)
        window.multi_view_ai_rects[viewer_index] = None

    self._handle_multi_view_ai_click_point(pos, viewer_index, positive=True)
990
+
991
+ def _handle_multi_view_ai_bounding_box(self, rect, viewer_index):
992
+ """Handle AI bounding box prediction for a specific viewer in multi-view mode."""
993
+ # Similar to single-view _handle_ai_bounding_box but for specific viewer
994
+ if viewer_index >= len(self.main_window.multi_view_models):
995
+ return
996
+
997
+ model = self.main_window.multi_view_models[viewer_index]
998
+ if model is None:
999
+ logger.error(f"Model not initialized for viewer {viewer_index}")
1000
+ return
1001
+
1002
+ try:
1003
+ # Convert QRectF to SAM box format [x1, y1, x2, y2]
1004
+ # from PyQt6.QtCore import QPointF
1005
+ # top_left = QPointF(rect.left(), rect.top())
1006
+ # bottom_right = QPointF(rect.right(), rect.bottom())
1007
+
1008
+ # For multi-view, we need to transform coordinates to model space
1009
+ # This would need the coordinate transformation logic
1010
+ box = [rect.left(), rect.top(), rect.right(), rect.bottom()]
1011
+
1012
+ # Generate mask using bounding box
1013
+ result = model.predict_from_box(box)
1014
+
1015
+ if result is not None and len(result) == 3:
1016
+ mask, scores, logits = result
1017
+
1018
+ # Ensure mask is boolean
1019
+ if mask.dtype != bool:
1020
+ mask = mask > 0.5
1021
+
1022
+ # Store prediction data for potential saving
1023
+ if not hasattr(self.main_window, "multi_view_ai_predictions"):
1024
+ self.main_window.multi_view_ai_predictions = {}
1025
+
1026
+ self.main_window.multi_view_ai_predictions[viewer_index] = {
1027
+ "mask": mask.astype(bool),
1028
+ "box": box,
1029
+ "points": [], # Empty for box predictions
1030
+ "labels": [], # Empty for box predictions
1031
+ }
1032
+
1033
+ # Show preview mask
1034
+ self._display_ai_preview(mask, viewer_index)
1035
+
1036
+ # Generate predictions for all other viewers with same bounding box
1037
+ other_viewer_indices = self._get_other_viewer_indices(viewer_index)
1038
+ for other_viewer_index in other_viewer_indices:
1039
+ self._generate_paired_ai_bbox_preview(
1040
+ viewer_index, other_viewer_index, box
1041
+ )
1042
+
1043
+ except Exception as e:
1044
+ logger.error(
1045
+ f"Error processing AI bounding box for viewer {viewer_index}: {e}"
1046
+ )
1047
+
1048
+ def _handle_multi_view_ai_click_point(self, pos, viewer_index, positive=True):
1049
+ """Handle AI point click for a specific viewer (extracted from handle_ai_click)."""
1050
+ model = self.main_window.multi_view_models[viewer_index]
1051
+ if model is None:
1052
+ logger.error(f"Model not initialized for viewer {viewer_index}")
1053
+ return
1054
+ viewer = self.main_window.multi_view_viewers[viewer_index]
1055
+
1056
+ # Add visual point to the viewer
1057
+ point_color = QColor(0, 255, 0) if positive else QColor(255, 0, 0)
1058
+ point_diameter = self.main_window.point_radius * 2
1059
+ point_item = QGraphicsEllipseItem(
1060
+ pos.x() - self.main_window.point_radius,
1061
+ pos.y() - self.main_window.point_radius,
1062
+ point_diameter,
1063
+ point_diameter,
1064
+ )
1065
+ point_item.setBrush(QBrush(point_color))
1066
+ point_item.setPen(QPen(Qt.PenStyle.NoPen))
1067
+ viewer.scene().addItem(point_item)
1068
+
1069
+ # Track point items for clearing
1070
+ if not hasattr(self.main_window, "multi_view_point_items"):
1071
+ # Initialize with dynamic viewer count
1072
+ num_viewers = self._get_num_viewers()
1073
+ self.main_window.multi_view_point_items = {
1074
+ i: [] for i in range(num_viewers)
1075
+ }
1076
+ self.main_window.multi_view_point_items[viewer_index].append(point_item)
1077
+
1078
+ # Record the action for undo
1079
+ self.main_window.action_history.append(
1080
+ {
1081
+ "type": "add_point",
1082
+ "point_type": "positive" if positive else "negative",
1083
+ "point_coords": [int(pos.x()), int(pos.y())],
1084
+ "point_item": point_item,
1085
+ "viewer_mode": "multi",
1086
+ "viewer_index": viewer_index,
1087
+ }
1088
+ )
1089
+ # Clear redo history when a new action is performed
1090
+ self.main_window.redo_history.clear()
1091
+
1092
+ # Process with SAM model
1093
+ try:
1094
+ # Convert position to model coordinates
1095
+ model_pos = self.main_window._transform_multi_view_coords_to_sam_coords(
1096
+ pos, viewer_index
1097
+ )
1098
+
1099
+ # Initialize point accumulation for multiview mode (like single view)
1100
+ if not hasattr(self.main_window, "multi_view_positive_points"):
1101
+ num_viewers = self._get_num_viewers()
1102
+ self.main_window.multi_view_positive_points = {
1103
+ i: [] for i in range(num_viewers)
1104
+ }
1105
+ if not hasattr(self.main_window, "multi_view_negative_points"):
1106
+ num_viewers = self._get_num_viewers()
1107
+ self.main_window.multi_view_negative_points = {
1108
+ i: [] for i in range(num_viewers)
1109
+ }
1110
+
1111
+ # Add current point to accumulated lists
1112
+ if positive:
1113
+ self.main_window.multi_view_positive_points[viewer_index].append(
1114
+ model_pos
1115
+ )
1116
+ else:
1117
+ self.main_window.multi_view_negative_points[viewer_index].append(
1118
+ model_pos
1119
+ )
1120
+
1121
+ # Prepare points for prediction using ALL accumulated points (like single view mode)
1122
+ positive_points = self.main_window.multi_view_positive_points[viewer_index]
1123
+ negative_points = self.main_window.multi_view_negative_points[viewer_index]
1124
+
1125
+ # Generate mask using the specific model
1126
+ result = model.predict(positive_points, negative_points)
1127
+
1128
+ if result is not None and len(result) == 3:
1129
+ # Unpack the tuple like single view mode
1130
+ mask, scores, logits = result
1131
+
1132
+ # Ensure mask is boolean (SAM models can return float masks)
1133
+ if mask.dtype != bool:
1134
+ mask = mask > 0.5
1135
+
1136
+ # Store prediction data for potential saving
1137
+ if not hasattr(self.main_window, "multi_view_ai_predictions"):
1138
+ self.main_window.multi_view_ai_predictions = {}
1139
+
1140
+ # Store all accumulated points, not just current point
1141
+ all_points = []
1142
+ all_labels = []
1143
+
1144
+ # Add all positive points
1145
+ for pt in positive_points:
1146
+ all_points.append(pt)
1147
+ all_labels.append(1)
1148
+
1149
+ # Add all negative points
1150
+ for pt in negative_points:
1151
+ all_points.append(pt)
1152
+ all_labels.append(0)
1153
+
1154
+ self.main_window.multi_view_ai_predictions[viewer_index] = {
1155
+ "mask": mask.astype(bool),
1156
+ "points": all_points,
1157
+ "labels": all_labels,
1158
+ "model_pos": model_pos,
1159
+ "positive": positive,
1160
+ }
1161
+
1162
+ # Show preview mask
1163
+ self._display_ai_preview(mask, viewer_index)
1164
+
1165
+ # Generate predictions for all other viewers with same coordinates
1166
+ other_viewer_indices = self._get_other_viewer_indices(viewer_index)
1167
+ for other_viewer_index in other_viewer_indices:
1168
+ self._generate_paired_ai_preview(
1169
+ viewer_index, other_viewer_index, pos, model_pos, positive
1170
+ )
1171
+
1172
+ except Exception as e:
1173
+ logger.error(f"Error processing AI click for viewer {viewer_index}: {e}")