singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,311 @@
+ """Export video with spatial attention heatmap overlay."""
+
+ import logging
+ import os
+ import math
+ import cv2
+ import numpy as np
+ from PyQt6.QtWidgets import (
+     QFileDialog, QMessageBox, QProgressDialog,
+ )
+ from PyQt6.QtCore import Qt
+
+ logger = logging.getLogger(__name__)
+
+
+ def _crop_frame_to_roi(frame_bgr, bbox_norm, out_w, out_h):
+     """Crop a frame to normalized xyxy ROI and resize to output size."""
+     if bbox_norm is None:
+         return cv2.resize(frame_bgr, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
+
+     h, w = frame_bgr.shape[:2]
+     try:
+         x1, y1, x2, y2 = [float(v) for v in bbox_norm]
+     except Exception:
+         return cv2.resize(frame_bgr, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
+
+     fx1 = int(round(x1 * w))
+     fy1 = int(round(y1 * h))
+     fx2 = int(round(x2 * w))
+     fy2 = int(round(y2 * h))
+     fx1 = max(0, min(fx1, w - 1))
+     fy1 = max(0, min(fy1, h - 1))
+     fx2 = max(fx1 + 1, min(fx2, w))
+     fy2 = max(fy1 + 1, min(fy2, h))
+     crop = frame_bgr[fy1:fy2, fx1:fx2]
+     if crop.size == 0:
+         crop = frame_bgr
+     return cv2.resize(crop, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
+
+
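The helper above maps a normalized `[x1, y1, x2, y2]` box onto pixel coordinates, clamps it so the crop is never empty, and falls back to a plain full-frame resize whenever the box is missing or malformed. A minimal usage sketch with an invented frame size and ROI (`_crop_frame_to_roi` is the function from the diff above):

```python
import numpy as np

# Dummy 480x640 BGR frame and a centered ROI covering the middle half.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
bbox_norm = (0.25, 0.25, 0.75, 0.75)  # normalized xyxy

# Pixel box becomes (160, 120)-(480, 360); the crop is resized to 224x224.
out = _crop_frame_to_roi(frame, bbox_norm, 224, 224)
assert out.shape == (224, 224, 3)

# A malformed box degrades gracefully to a full-frame resize.
out = _crop_frame_to_roi(frame, ("bad", 0, 0, 0), 224, 224)
assert out.shape == (224, 224, 3)
```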
+ def _get_final_label_for_frame(widget, frame_idx, clip_idx, clip_starts, clip_length, frame_interval, classes, ignore_label):
+     """Match overlay export: use precise postprocessed label when available."""
+     use_precise = bool(
+         hasattr(widget, "frame_aggregation_check")
+         and widget.frame_aggregation_check.isChecked()
+         and getattr(widget, "aggregated_segments", None)
+     )
+
+     if use_precise and getattr(widget, "_use_ovr", False) and isinstance(
+         getattr(widget, "_aggregated_frame_scores_norm", None), np.ndarray
+     ):
+         active_infos = widget._get_precise_active_for_frame(frame_idx)
+         if active_infos:
+             pred_idx, _ = active_infos[0]
+             if pred_idx < 0:
+                 return ignore_label, (120, 120, 120)
+             if 0 <= pred_idx < len(classes):
+                 return classes[pred_idx], (255, 255, 255)
+             return f"class_{pred_idx}", (255, 255, 255)
+
+     if use_precise:
+         segments = getattr(widget, "aggregated_segments", None) or []
+         for seg in segments:
+             s0 = int(seg.get("start", 0))
+             s1 = int(seg.get("end", s0))
+             if s0 <= frame_idx <= s1:
+                 pred_idx = int(seg.get("class", -1))
+                 if pred_idx < 0:
+                     return ignore_label, (120, 120, 120)
+                 if 0 <= pred_idx < len(classes):
+                     return classes[pred_idx], (255, 255, 255)
+                 return f"class_{pred_idx}", (255, 255, 255)
+
+     if clip_idx is not None and clip_starts:
+         if hasattr(widget, "_effective_prediction_for_clip"):
+             pred_idx = int(widget._effective_prediction_for_clip(clip_idx))
+         else:
+             pred_idx = None
+             if hasattr(widget, "_effective_predictions") and hasattr(widget, "predictions") and widget.predictions:
+                 effective_preds = widget._effective_predictions()
+                 if clip_idx < len(effective_preds):
+                     pred_idx = int(effective_preds[clip_idx])
+         if pred_idx is not None:
+             if pred_idx < 0:
+                 return ignore_label, (120, 120, 120)
+             if 0 <= pred_idx < len(classes):
+                 return classes[pred_idx], (255, 255, 255)
+             return f"class_{pred_idx}", (255, 255, 255)
+
+     return None, None
+
+
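`_get_final_label_for_frame` resolves a label in three tiers: per-frame one-vs-rest scores when precise aggregation is active, then aggregated `{start, end, class}` segments, then the per-clip effective prediction; a negative class index always maps to the ignore label. A self-contained sketch of the middle (segment-lookup) tier, with invented segments and class names; the real function reads these from the widget:

```python
# Hypothetical data: inclusive [start, end] frame ranges per segment;
# class -1 denotes the ignore label.
segments = [
    {"start": 0, "end": 89, "class": 1},
    {"start": 90, "end": 149, "class": -1},
]
classes = ["groom", "rear"]

def label_for(frame_idx):
    for seg in segments:
        if seg["start"] <= frame_idx <= seg["end"]:
            idx = seg["class"]
            return "Ignored" if idx < 0 else classes[idx]
    return None  # frame not covered by any segment

assert label_for(10) == "rear"
assert label_for(100) == "Ignored"
assert label_for(200) is None
```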
+ def export_attention_heatmap_video(widget):
+     """Generate and save a video with attention heatmaps overlaid on frames.
+
+     Uses effective predictions (with ignore threshold + manual corrections)
+     for labels. Applies temporal smoothing and Gaussian blur for smooth
+     heatmaps that aren't jumpy or blocky.
+     """
+     if not hasattr(widget, 'results_cache') or not widget.results_cache:
+         QMessageBox.warning(widget, "No results", "Run inference with 'Collect attention maps' enabled first.")
+         return
+
+     video_path = None
+     if hasattr(widget, 'filter_video_combo') and widget.filter_video_combo.currentData():
+         video_path = widget.filter_video_combo.currentData()
+     elif widget.video_path:
+         video_path = widget.video_path
+
+     if not video_path or video_path not in widget.results_cache:
+         QMessageBox.warning(widget, "No video", "Select a video with inference results.")
+         return
+
+     res = widget.results_cache[video_path]
+     attn_maps = res.get("clip_attention_maps")
+     if not attn_maps:
+         QMessageBox.warning(
+             widget, "No attention data",
+             "No attention maps available. Re-run inference with 'Collect attention maps' checked."
+         )
+         return
+
+     clip_starts = res.get("clip_starts", [])
+     classes = getattr(widget, 'classes', [])
+
+     ignore_label = getattr(widget, 'ignore_label_name', 'Ignored')
+
+     default_name = os.path.splitext(os.path.basename(video_path))[0] + "_attention.mp4"
+     default_dir = os.path.dirname(video_path)
+     output_path, _ = QFileDialog.getSaveFileName(
+         widget, "Save attention heatmap video",
+         os.path.join(default_dir, default_name),
+         "MP4 Video (*.mp4);;AVI Video (*.avi)"
+     )
+     if not output_path:
+         return
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         QMessageBox.critical(widget, "Error", f"Cannot open video: {video_path}")
+         return
+
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     orig_fps = cap.get(cv2.CAP_PROP_FPS)
+     if orig_fps <= 0:
+         orig_fps = 30.0
+     frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     target_fps = getattr(widget, 'target_fps_spin', None)
+     target_fps_val = target_fps.value() if target_fps else orig_fps
+     frame_interval = max(1, int(round(orig_fps / max(1e-6, float(target_fps_val)))))
+     clip_length = getattr(widget, 'clip_length_spin', None)
+     clip_length_val = clip_length.value() if clip_length else 8
+     use_classification_roi = bool(
+         getattr(widget, "localization_bboxes", None)
+         and hasattr(widget, "_get_classification_roi_bbox_for_clip_frame")
+     )
+     crop_resolution = int(getattr(widget, "infer_resolution", 0) or 0)
+     if hasattr(widget, "resolution_spin"):
+         try:
+             crop_resolution = int(widget.resolution_spin.value())
+         except Exception as e:
+             logger.debug("Could not read resolution from resolution_spin: %s", e)
+     crop_resolution = max(64, crop_resolution or min(frame_w, frame_h))
+     output_w = crop_resolution if use_classification_roi else frame_w
+     output_h = crop_resolution if use_classification_roi else frame_h
+
+     # Map each covered video frame index to its 2D attention heatmap
+     attn_keyframes = {}
+     grid_side = None
+     for clip_idx, attn in enumerate(attn_maps):
+         if attn is None or clip_idx >= len(clip_starts):
+             continue
+         attn_arr = np.array(attn, dtype=np.float32)  # [T_clip, num_heads, S]
+         attn_avg = attn_arr.mean(axis=1)  # [T_clip, S]
+         gs = int(math.isqrt(attn_avg.shape[-1]))
+         if gs * gs != attn_avg.shape[-1]:
+             continue
+         grid_side = gs
+
+         start_frame = clip_starts[clip_idx]
+         for t in range(attn_avg.shape[0]):
+             vid_frame = start_frame + t * frame_interval
+             if vid_frame < total_frames:
+                 attn_keyframes[vid_frame] = attn_avg[t].reshape(gs, gs)
+
+     if not attn_keyframes or grid_side is None:
+         cap.release()
+         QMessageBox.warning(widget, "No data", "Could not map attention data to video frames.")
+         return
+
+     # Upscale to export resolution so heatmap has maximum detail when overlaid
+     interp_size = min(output_w, output_h)
+
+     # Gaussian kernel to smooth grid boundaries (scale with grid cell size at interp)
+     blur_ksize = max(3, min(interp_size // 24, grid_side * 4) | 1)
+
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     writer = cv2.VideoWriter(output_path, fourcc, orig_fps, (output_w, output_h))
+
+     progress = QProgressDialog("Rendering attention heatmap video...", "Cancel", 0, total_frames, widget)
+     progress.setWindowModality(Qt.WindowModality.WindowModal)
+     progress.setMinimumDuration(0)
+
+     # Temporal EMA smoothing for heatmap continuity
+     ema_decay = 0.7
+     smooth_heatmap = None
+
+     # Scale label text to video size (readable on small and large frames)
+     ref_size = 480
+     small = min(output_w, output_h)
+     text_scale = max(0.35, min(2.0, small / ref_size))
+     font_scale = round(text_scale * 10) / 10.0
+     thickness = max(1, int(round(text_scale * 2)))
+
+     for frame_idx in range(total_frames):
+         if progress.wasCanceled():
+             break
+
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         clip_idx = None
+         if clip_starts:
+             clip_idx = _find_clip_for_frame(frame_idx, clip_starts, clip_length_val, frame_interval)
+         if use_classification_roi and clip_idx is not None:
+             roi_bbox = widget._get_classification_roi_bbox_for_clip_frame(clip_idx)
+             base_frame = _crop_frame_to_roi(frame, roi_bbox, output_w, output_h)
+         elif use_classification_roi:
+             base_frame = cv2.resize(frame, (output_w, output_h), interpolation=cv2.INTER_LINEAR)
+         else:
+             base_frame = frame
+
+         if frame_idx in attn_keyframes:
+             raw = attn_keyframes[frame_idx]
+             if smooth_heatmap is None:
+                 smooth_heatmap = raw.copy()
+             else:
+                 smooth_heatmap = ema_decay * smooth_heatmap + (1 - ema_decay) * raw
+         # else: keep previous smooth_heatmap (carry forward)
+
+         if smooth_heatmap is None:
+             writer.write(base_frame)
+             if frame_idx % 200 == 0:
+                 progress.setValue(frame_idx)
+             continue
+
+         # Normalize to 0-1
+         h_min, h_max = smooth_heatmap.min(), smooth_heatmap.max()
+         if h_max > h_min:
+             heatmap_01 = (smooth_heatmap - h_min) / (h_max - h_min)
+         else:
+             heatmap_01 = np.zeros_like(smooth_heatmap)
+
+         # Upscale to intermediate resolution with bicubic
+         heatmap_up = cv2.resize(heatmap_01.astype(np.float32), (interp_size, interp_size),
+                                 interpolation=cv2.INTER_CUBIC)
+
+         # Gaussian blur for smooth appearance
+         heatmap_up = cv2.GaussianBlur(heatmap_up, (blur_ksize, blur_ksize), 0)
+
+         # Final resize to video dimensions
+         heatmap_final = cv2.resize(heatmap_up, (output_w, output_h),
+                                    interpolation=cv2.INTER_LINEAR)
+
+         # Clip to valid range after interpolation
+         heatmap_final = np.clip(heatmap_final, 0, 1)
+         heatmap_u8 = (heatmap_final * 255).astype(np.uint8)
+
+         heatmap_color = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
+
+         blended = cv2.addWeighted(base_frame, 0.6, heatmap_color, 0.4, 0)
+
+         label, color = _get_final_label_for_frame(
+             widget, frame_idx, clip_idx, clip_starts, clip_length_val, frame_interval, classes, ignore_label
+         )
+         if label:
+             x_label = max(8, int(round(10 * text_scale)))
+             y_label = max(20, int(round(30 * text_scale)))
+             cv2.putText(blended, label, (x_label, y_label), cv2.FONT_HERSHEY_SIMPLEX,
+                         font_scale, color, thickness, cv2.LINE_AA)
+
+         writer.write(blended)
+
+         if frame_idx % 200 == 0:
+             progress.setValue(frame_idx)
+
+     progress.setValue(total_frames)
+     cap.release()
+     writer.release()
+
+     if not progress.wasCanceled():
+         QMessageBox.information(
+             widget, "Export complete",
+             f"Attention heatmap video saved to:\n{output_path}"
+         )
+
+         from .overlay_export import VideoPreviewDialog
+         dialog = VideoPreviewDialog(output_path, parent=widget)
+         dialog.exec()
+
+
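The inner loop keeps the overlay stable in two ways: an exponential moving average (`smooth = 0.7 * smooth + 0.3 * raw`) carries the heatmap forward between attention keyframes so it never flickers, and a bicubic upscale plus Gaussian blur hides the coarse token grid before the JET colormap and 60/40 alpha blend. A standalone sketch of that per-frame pipeline, with an invented 16x16 attention grid, decay, and frame size:

```python
import cv2
import numpy as np

ema_decay = 0.7
smooth = None
rng = np.random.default_rng(0)

frame = np.full((256, 256, 3), 128, dtype=np.uint8)  # dummy BGR frame
for _ in range(5):  # pretend five attention keyframes arrive
    raw = rng.random((16, 16), dtype=np.float32)     # 16x16 attention grid
    # EMA: heavy weight on history keeps the overlay from jumping.
    smooth = raw.copy() if smooth is None else ema_decay * smooth + (1 - ema_decay) * raw

# Min-max normalize to 0-1, guarding against a constant map.
h_min, h_max = smooth.min(), smooth.max()
heat01 = (smooth - h_min) / (h_max - h_min) if h_max > h_min else np.zeros_like(smooth)

heat_up = cv2.resize(heat01, (256, 256), interpolation=cv2.INTER_CUBIC)
heat_up = cv2.GaussianBlur(heat_up, (11, 11), 0)     # kernel size must be odd
heat_u8 = (np.clip(heat_up, 0, 1) * 255).astype(np.uint8)
overlay = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
blended = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)  # 60% frame, 40% heatmap
```

Note the `| 1` in the exporter's `blur_ksize` computation serves the same purpose as the hard-coded `(11, 11)` here: `cv2.GaussianBlur` requires an odd kernel size.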
+ def _find_clip_for_frame(frame_idx, clip_starts, clip_length, frame_interval):
+     """Find which clip index covers the given video frame."""
+     for i in range(len(clip_starts) - 1, -1, -1):
+         clip_end = clip_starts[i] + clip_length * frame_interval
+         if clip_starts[i] <= frame_idx < clip_end:
+             return i
+     return None
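Each clip covers `clip_length * frame_interval` original video frames, and the lookup scans clips from last to first so that, when sampled clips overlap, the most recent clip wins. A worked example with invented numbers:

```python
# Clips of 8 sampled frames, sampled every 3rd video frame, starting at
# video frames 0 and 24; each clip therefore spans 8 * 3 = 24 frames.
clip_starts = [0, 24]
clip_length = 8
frame_interval = 3

assert _find_clip_for_frame(0, clip_starts, clip_length, frame_interval) == 0
assert _find_clip_for_frame(23, clip_starts, clip_length, frame_interval) == 0
assert _find_clip_for_frame(24, clip_starts, clip_length, frame_interval) == 1
assert _find_clip_for_frame(48, clip_starts, clip_length, frame_interval) is None
```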