singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,688 @@
1
+ """Video processing with mask-based background removal and centering."""
2
+ import logging
3
+ import os
4
+ import cv2
5
+ import h5py
6
+ import numpy as np
7
+ from typing import Optional, Tuple
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def load_segmentation_data(file_path: str) -> dict:
    """Load per-frame segmentation objects from an HDF5 (.h5/.hdf5) file.

    Returns:
        Dict with keys: video_path, total_frames, height, width, fps,
        frame_objects (one list per frame of {'bbox', 'mask', 'obj_id'}
        dicts), objects_per_frame, tracker, format, start_offset,
        original_total_frames.

    Raises:
        ValueError: If the file extension is not .h5/.hdf5.
        KeyError: If required datasets/attributes are missing from the file.
    """
    if not file_path.lower().endswith((".h5", ".hdf5")):
        raise ValueError(f"Unsupported segmentation format: {file_path}")

    with h5py.File(file_path, "r") as f:
        frames_grp = f["frame_objects"]
        n_frames = int(f.attrs.get("total_frames", len(frames_grp)))

        per_frame = []
        for idx in range(n_frames):
            key = f"frame_{idx:06d}"
            objs = []
            # A frame with no stored group simply has no objects.
            if key in frames_grp:
                grp = frames_grp[key]
                for obj_name in sorted(grp.keys()):
                    entry = grp[obj_name]
                    objs.append(
                        {
                            "bbox": tuple(int(v) for v in entry.attrs["bbox"]),
                            "mask": entry["mask"][()].astype(bool),
                            "obj_id": int(entry.attrs["obj_id"]),
                        }
                    )
            per_frame.append(objs)

        return {
            "video_path": f.attrs.get("video_path", file_path),
            "total_frames": n_frames,
            "height": int(f.attrs["height"]),
            "width": int(f.attrs["width"]),
            "fps": float(f.attrs.get("fps", 30.0)),
            "frame_objects": per_frame,
            "objects_per_frame": [int(v) for v in f["objects_per_frame"][()]],
            "tracker": {},
            "format": "hdf5",
            "start_offset": int(f.attrs.get("start_offset", 0)),
            "original_total_frames": int(f.attrs.get("original_total_frames", n_frames)),
        }
50
+
51
+
52
def save_segmentation_data(file_path: str, mask_data: dict) -> None:
    """Write per-frame segmentation objects to an HDF5 file.

    Layout mirrors ``load_segmentation_data``: file-level attrs
    (video_path, total_frames, height, width, fps, start_offset,
    original_total_frames), an ``objects_per_frame`` dataset, and a
    ``frame_objects`` group holding one ``frame_XXXXXX`` subgroup per
    frame, each with ``obj_XXX`` subgroups (attrs: obj_id, bbox;
    dataset: mask, stored as uint8 with gzip compression).

    Args:
        file_path: Destination .h5/.hdf5 path; parent directories are created.
        mask_data: Dict in the shape produced by load_segmentation_data.
            Missing keys fall back to sensible defaults.
    """
    parent_dir = os.path.dirname(file_path)
    # Guard: os.makedirs("") raises FileNotFoundError when file_path is a
    # bare filename with no directory component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    frame_objects = mask_data.get("frame_objects", [])
    with h5py.File(file_path, "w") as f:
        f.attrs["video_path"] = str(mask_data.get("video_path", ""))
        f.attrs["total_frames"] = int(mask_data.get("total_frames", len(frame_objects)))
        f.attrs["height"] = int(mask_data.get("height", 0))
        f.attrs["width"] = int(mask_data.get("width", 0))
        f.attrs["fps"] = float(mask_data.get("fps", 30.0))
        f.attrs["start_offset"] = int(mask_data.get("start_offset", 0))
        f.attrs["original_total_frames"] = int(
            mask_data.get("original_total_frames", len(frame_objects))
        )

        # Default to counting objects directly when the caller did not
        # precompute objects_per_frame.
        objects_per_frame = np.asarray(
            mask_data.get("objects_per_frame", [len(objs) for objs in frame_objects]),
            dtype=np.int32,
        )
        f.create_dataset("objects_per_frame", data=objects_per_frame, compression="gzip")

        frames_group = f.create_group("frame_objects")
        for frame_idx, objs in enumerate(frame_objects):
            frame_group = frames_group.create_group(f"frame_{frame_idx:06d}")
            for obj_idx, obj in enumerate(objs):
                obj_group = frame_group.create_group(f"obj_{obj_idx:03d}")
                obj_group.attrs["obj_id"] = int(obj.get("obj_id", 0))
                obj_group.attrs["bbox"] = np.asarray(obj.get("bbox", (0, 0, 0, 0)), dtype=np.int32)
                # Masks are stored as uint8 (0/1) and read back as bool.
                mask = np.asarray(obj.get("mask", np.zeros((1, 1), dtype=bool)), dtype=np.uint8)
                obj_group.create_dataset(
                    "mask",
                    data=mask,
                    compression="gzip",
                    compression_opts=4,
                    shuffle=True,
                )
87
+
88
+
89
def convert_old_format_to_objects(data: np.ndarray) -> list:
    """Convert an old-format mask array to a frame_objects list.

    Args:
        data: Array of shape (frames, channels, height, width); only
            channel 0 is read.

    Returns:
        One list per frame: either empty (no foreground pixels) or a
        single-element list holding {'bbox': (x_min, y_min, x_max, y_max),
        'mask': bool array cropped to the bbox, 'obj_id': 0}.
    """
    n_frames, _channels, _height, _width = data.shape
    result = []
    for idx in range(n_frames):
        plane = data[idx, 0, :, :]
        if not np.any(plane):
            result.append([])
            continue
        ys, xs = np.where(plane)
        if ys.size == 0 or xs.size == 0:
            result.append([])
            continue
        top, bottom = np.min(ys), np.max(ys)
        left, right = np.min(xs), np.max(xs)
        # Crop the mask to its tight bounding box (inclusive bounds).
        cropped = plane[top:bottom + 1, left:right + 1]
        result.append([
            {
                'bbox': (left, top, right, bottom),
                'mask': cropped.astype(bool),
                'obj_id': 0,
            }
        ])
    return result
112
+
113
+
114
def read_video_frames(video_path: str, expected_frames: Optional[int] = None) -> Tuple[list, float]:
    """Read frames from a video file into memory.

    Args:
        video_path: Path to the input video.
        expected_frames: If given, stop reading after this many frames.

    Returns:
        Tuple of (frames, fps) where frames is a list of BGR ndarrays and
        fps is the container frame rate (falls back to 30.0 when the
        container reports a non-positive rate, matching
        process_video_to_clips).

    Raises:
        ValueError: If the video cannot be opened.
    """
    # Fixed: annotation previously said `-> list` but the function has
    # always returned a (frames, fps) tuple.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Some containers report 0; use the same fallback as
            # process_video_to_clips so downstream VideoWriters get a
            # usable rate.
            fps = 30.0

        frames = []
        while True:
            if expected_frames is not None and len(frames) >= expected_frames:
                break
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture even if decoding raises mid-loop.
        cap.release()
    return frames, fps
132
+
133
+
134
def process_frame_with_mask(
    frame_bgr: np.ndarray,
    frame_objects: list,
    frame_idx: int,
    box_size: int = 250,
    target_size: int = 288,
    background_mode: str = 'white',
    normalization_method: str = 'CLAHE',
    mask_feather_px: int = 0,
    anchor_cx: Optional[float] = None,
    anchor_cy: Optional[float] = None,
    anchor_mode: str = 'frame',
    obj_id: Optional[int] = None
) -> np.ndarray:
    """
    Process a single frame: crop around centroid and remove background.

    Args:
        frame_bgr: Input frame in BGR channel order (as read by OpenCV).
        frame_objects: Per-frame lists of {'bbox', 'mask', 'obj_id'} dicts.
        frame_idx: Index into frame_objects for this frame.
        box_size: Side length (px) of the square crop taken from the frame.
        target_size: Side length (px) of the returned image.
        background_mode: 'black', 'gray', 'blur', or 'white' composites the
            background away; any other value leaves the crop unmodified.
        normalization_method: 'CLAHE', 'Histogram Equalization', 'Mean-Variance', or 'None'
        mask_feather_px: Gaussian feather radius for soft mask edges
            (0 disables feathering; odd kernel size is forced).
        anchor_cx: Fixed crop-center x used when anchor_mode == 'first'.
        anchor_cy: Fixed crop-center y used when anchor_mode == 'first'.
        anchor_mode: 'frame' (per-frame mask centroid) or 'first'
            (use anchor_cx/anchor_cy when both are provided).
        obj_id: If provided, only process this object ID. Otherwise, use first object.

    Returns processed frame (target_size, target_size, 3) in [0,1] float32.
    An all-zero image is returned when the frame has no (matching) object.
    """
    h, w = frame_bgr.shape[:2]

    # No mask data for this frame -> blank output of the expected shape.
    if frame_idx >= len(frame_objects) or not frame_objects[frame_idx]:
        return np.zeros((target_size, target_size, 3), dtype=np.float32)

    # Select the requested object; str() comparison tolerates int/str id
    # mismatches between caller and stored data.
    obj = None
    if obj_id is not None:
        for o in frame_objects[frame_idx]:
            o_id = o.get('obj_id')
            if str(o_id) == str(obj_id):
                obj = o
                break
        if obj is None:
            return np.zeros((target_size, target_size, 3), dtype=np.float32)
    else:
        obj = frame_objects[frame_idx][0]
    mask = obj['mask']
    x_min, y_min, x_max, y_max = obj['bbox']

    # Paste the bbox-cropped mask back into a full-frame binary mask,
    # clamping to both the mask's own extent and the frame borders.
    full_mask = np.zeros((h, w), dtype=np.uint8)
    bbox_h = max(0, y_max - y_min + 1)
    bbox_w = max(0, x_max - x_min + 1)

    if bbox_h > 0 and bbox_w > 0:
        mh, mw = mask.shape
        h_use = min(bbox_h, mh, h - max(0, y_min))
        w_use = min(bbox_w, mw, w - max(0, x_min))
        if h_use > 0 and w_use > 0:
            try:
                # NOTE(review): a negative x_min/y_min would make these
                # slices wrap from the array end; the except masks that.
                # Confirm bboxes are non-negative upstream.
                full_mask[y_min:y_min + h_use, x_min:x_min + w_use] = (
                    mask[:h_use, :w_use] > 0
                ).astype(np.uint8)
            except Exception:
                pass

    # Crop center: mask centroid via image moments, falling back to the
    # bbox center when the mask is empty (m00 == 0).
    m = cv2.moments(full_mask, binaryImage=True)
    if m['m00'] > 0:
        cx = m['m10'] / m['m00']
        cy = m['m01'] / m['m00']
    else:
        cx = (x_min + x_max) / 2.0
        cy = (y_min + y_max) / 2.0

    # 'first' mode pins the crop window at the supplied anchor instead of
    # tracking the per-frame centroid.
    if anchor_mode == 'first' and anchor_cx is not None and anchor_cy is not None:
        cx_use, cy_use = anchor_cx, anchor_cy
    else:
        cx_use, cy_use = cx, cy

    # Desired box_size x box_size window centered on (cx_use, cy_use);
    # it may extend past the frame borders.
    half = box_size // 2
    desired_x1 = int(round(cx_use)) - half
    desired_y1 = int(round(cy_use)) - half
    desired_x2 = desired_x1 + box_size
    desired_y2 = desired_y1 + box_size

    # Portion of the desired window that actually lies inside the frame.
    src_x1 = max(0, desired_x1)
    src_y1 = max(0, desired_y1)
    src_x2 = min(w, desired_x2)
    src_y2 = min(h, desired_y2)

    if src_x2 <= src_x1 or src_y2 <= src_y1:
        # Window entirely off-frame: empty crop and mask.
        crop_rgb = np.zeros((box_size, box_size, 3), dtype=np.uint8)
        crop_mask = np.zeros((box_size, box_size), dtype=np.uint8)
    else:
        crop_bgr = frame_bgr[src_y1:src_y2, src_x1:src_x2]
        crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
        crop_mask = full_mask[src_y1:src_y2, src_x1:src_x2]

        # Zero-pad the crop back to box_size when the window was clipped
        # at a frame border, keeping the subject at the window center.
        pad_left = src_x1 - desired_x1
        pad_top = src_y1 - desired_y1
        pad_right = desired_x2 - src_x2
        pad_bottom = desired_y2 - src_y2

        if pad_left or pad_top or pad_right or pad_bottom:
            padded_rgb = np.zeros((box_size, box_size, 3), dtype=np.uint8)
            padded_mask = np.zeros((box_size, box_size), dtype=np.uint8)
            h_c, w_c = crop_rgb.shape[:2]
            padded_rgb[pad_top:pad_top + h_c, pad_left:pad_left + w_c] = crop_rgb
            padded_mask[pad_top:pad_top + h_c, pad_left:pad_left + w_c] = crop_mask
            crop_rgb, crop_mask = padded_rgb, padded_mask

    # Contrast normalization operates on grayscale, then re-expands to
    # 3 channels, so all modes below produce grayscale-looking RGB output.
    if normalization_method == 'CLAHE':
        gray = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2GRAY)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray_eq = clahe.apply(gray)
        crop_rgb = cv2.cvtColor(gray_eq, cv2.COLOR_GRAY2RGB)

    elif normalization_method == 'Histogram Equalization':
        gray = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2GRAY)
        gray_eq = cv2.equalizeHist(gray)
        crop_rgb = cv2.cvtColor(gray_eq, cv2.COLOR_GRAY2RGB)

    elif normalization_method == 'Mean-Variance':
        gray = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2GRAY).astype(np.float32)
        mean, std = cv2.meanStdDev(gray)
        if std[0][0] > 0:
            # Rescale to std 50 around mid-gray 127.
            gray = (gray - mean[0][0]) / std[0][0] * 50 + 127
        else:
            # Flat image: just recenter on mid-gray.
            gray = gray - mean[0][0] + 127
        gray = np.clip(gray, 0, 255).astype(np.uint8)
        crop_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    # Defensive re-pad in case any step above left a non-square crop.
    if crop_rgb.shape[0] != box_size or crop_rgb.shape[1] != box_size:
        pad_img = np.zeros((box_size, box_size, 3), dtype=np.uint8)
        h_c, w_c = crop_rgb.shape[:2]
        pad_img[:h_c, :w_c] = crop_rgb
        crop_rgb = pad_img

        pad_mask = np.zeros((box_size, box_size), dtype=np.uint8)
        # NOTE(review): crop_mask is assigned on every path above, so this
        # locals() guard looks vestigial — confirm before removing.
        if 'crop_mask' in locals():
            pad_mask[:h_c, :w_c] = crop_mask
        crop_mask = pad_mask

    # Composite the selected background behind the (optionally feathered)
    # mask. Any other background_mode leaves the crop as-is.
    if background_mode in ('black', 'gray', 'blur', 'white'):
        m_float = crop_mask.astype(np.float32)
        if mask_feather_px and mask_feather_px > 0:
            # `| 1` forces an odd kernel size as GaussianBlur requires.
            k = max(1, int(mask_feather_px) | 1)
            m_float = cv2.GaussianBlur(m_float, (k, k), 0)
            m_float = np.clip(m_float, 0.0, 1.0)
        m3 = np.repeat(m_float[:, :, None], 3, axis=2)

        if background_mode == 'black':
            bg = np.zeros_like(crop_rgb)
        elif background_mode == 'gray':
            bg = np.full_like(crop_rgb, 128)
        elif background_mode == 'white':
            bg = np.full_like(crop_rgb, 255)
        else:  # blur
            bg = cv2.GaussianBlur(crop_rgb, (11, 11), 0)

        # Alpha blend: mask=1 keeps the subject, mask=0 shows background.
        crop_rgb = (m3 * crop_rgb + (1.0 - m3) * bg).astype(np.uint8)

    # Resize to target size
    frame_rgb = cv2.resize(crop_rgb, (target_size, target_size), interpolation=cv2.INTER_AREA)

    # Normalize to [0, 1]
    return (frame_rgb.astype(np.float32) / 255.0)
293
+
294
+
295
def process_video(
    video_path: str,
    mask_path: str,
    output_path: str,
    box_size: int = 250,
    target_size: int = 288,
    background_mode: str = 'white',
    mask_feather_px: int = 0,
    anchor_mode: str = 'first',
    progress_callback: Optional[callable] = None,
    obj_id: Optional[int] = None
) -> bool:
    """
    Process entire video: load mask, process each frame, save output.

    Args:
        video_path: Path to input video
        mask_path: Path to HDF5 mask file
        output_path: Path to save processed video (or base path if obj_id is None and multiple objects exist)
        box_size: Size of crop box in pixels
        target_size: Final output size
        background_mode: 'white', 'black', 'gray', 'blur', 'none'
        mask_feather_px: Feathering radius for mask edges
        anchor_mode: 'frame' (per-frame centroid) or 'first' (fixed at first frame)
        progress_callback: Optional function(frame_num, total_frames, obj_id) for progress updates
        obj_id: If provided, only process this object ID. If None and multiple objects exist, creates separate videos.

    Returns:
        True if successful, False otherwise. If multiple objects, returns list of output paths.
    """
    try:
        mask_data = load_segmentation_data(mask_path)
        frame_objects = mask_data['frame_objects']
        num_frames = len(frame_objects)

        # Collect every object id present in the mask file.
        # FIX: a previous version reused the name `obj_id` for the loop
        # variable, clobbering the caller's argument — a caller-supplied id
        # was ignored and the multi-object split below could never trigger.
        all_obj_ids = set()
        for frame_objs in frame_objects:
            for obj in frame_objs:
                all_obj_ids.add(obj.get('obj_id', 0))
        all_obj_ids = sorted(all_obj_ids)

        # Multiple objects and no specific id requested: fan out into one
        # output video per object and return the list of paths written.
        if obj_id is None and len(all_obj_ids) > 1:
            output_paths = []
            base_path, ext = os.path.splitext(output_path)

            for oid in all_obj_ids:
                obj_output_path = f"{base_path}_obj{oid}{ext}"
                success = process_video(
                    video_path, mask_path, obj_output_path,
                    box_size, target_size, background_mode,
                    mask_feather_px, anchor_mode, progress_callback, obj_id=oid
                )
                if success:
                    output_paths.append(obj_output_path)

            return output_paths if output_paths else False

        if obj_id is None:
            obj_id = all_obj_ids[0] if all_obj_ids else None

        raw_frames, fps = read_video_frames(video_path, expected_frames=num_frames)
        if len(raw_frames) != num_frames:
            logger.warning("Video has %d frames, mask has %d", len(raw_frames), num_frames)
            num_frames = min(len(raw_frames), num_frames)

        # For anchor_mode='first', fix the crop window at the target
        # object's centroid in the earliest of the first 10 frames that
        # contains it.
        anchor_cx = None
        anchor_cy = None
        if anchor_mode == 'first':
            for frame_idx in range(min(10, num_frames)):  # Check first 10 frames
                if frame_idx < len(frame_objects) and frame_objects[frame_idx]:
                    obj0 = None
                    for o in frame_objects[frame_idx]:
                        # str() comparison tolerates int/str id mismatches,
                        # consistent with process_frame_with_mask.
                        if str(o.get('obj_id')) == str(obj_id):
                            obj0 = o
                            break
                    if obj0 is None:
                        continue

                    # Rebuild the full-frame mask from the bbox-cropped one,
                    # clamped to mask extent and frame borders.
                    mask0 = obj0['mask']
                    x0_min, y0_min, x0_max, y0_max = obj0['bbox']
                    frame_bgr0 = raw_frames[frame_idx]
                    h0, w0 = frame_bgr0.shape[:2]
                    full_mask0 = np.zeros((h0, w0), dtype=np.uint8)
                    bbox_h0 = max(0, y0_max - y0_min + 1)
                    bbox_w0 = max(0, x0_max - x0_min + 1)
                    if bbox_h0 > 0 and bbox_w0 > 0:
                        mh0, mw0 = mask0.shape
                        h0_use = min(bbox_h0, mh0, h0 - max(0, y0_min))
                        w0_use = min(bbox_w0, mw0, w0 - max(0, x0_min))
                        if h0_use > 0 and w0_use > 0:
                            try:
                                full_mask0[y0_min:y0_min + h0_use, x0_min:x0_min + w0_use] = (
                                    mask0[:h0_use, :w0_use] > 0
                                ).astype(np.uint8)
                            except Exception:
                                pass
                    # Anchor at mask centroid, falling back to bbox center
                    # when the mask is empty; either way stop searching.
                    m0 = cv2.moments(full_mask0, binaryImage=True)
                    if m0['m00'] > 0:
                        anchor_cx = m0['m10'] / m0['m00']
                        anchor_cy = m0['m01'] / m0['m00']
                        break
                    else:
                        anchor_cx = (x0_min + x0_max) / 2.0
                        anchor_cy = (y0_min + y0_max) / 2.0
                        break

        # Use mp4v codec for efficient encoding
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (target_size, target_size), isColor=True)

        if not out.isOpened():
            raise ValueError(f"Failed to open video writer for {output_path}")

        try:
            for frame_idx in range(num_frames):
                if progress_callback:
                    if callable(progress_callback):
                        try:
                            # Prefer the 3-arg signature; fall back to 2-arg
                            # callbacks for backward compatibility.
                            progress_callback(frame_idx + 1, num_frames, obj_id)
                        except TypeError:
                            progress_callback(frame_idx + 1, num_frames)

                frame_bgr = raw_frames[frame_idx]

                frame_processed = process_frame_with_mask(
                    frame_bgr,
                    frame_objects,
                    frame_idx,
                    box_size=box_size,
                    target_size=target_size,
                    background_mode=background_mode,
                    mask_feather_px=mask_feather_px,
                    anchor_cx=anchor_cx,
                    anchor_cy=anchor_cy,
                    anchor_mode=anchor_mode,
                    obj_id=obj_id
                )

                # process_frame_with_mask returns RGB float32 in [0,1];
                # convert back to uint8 BGR for the writer.
                frame_bgr_out = (frame_processed * 255.0).astype(np.uint8)
                frame_bgr_out = cv2.cvtColor(frame_bgr_out, cv2.COLOR_RGB2BGR)
                out.write(frame_bgr_out)

            return True
        finally:
            out.release()

    except Exception as e:
        logger.error("Error processing video: %s", e, exc_info=True)
        return False
447
+
448
+
449
def process_video_to_clips(
    video_path: str,
    mask_path: str,
    output_dir: str,
    box_size: int = 250,
    target_size: int = 288,
    background_mode: str = 'white',
    normalization_method: str = 'CLAHE',
    mask_feather_px: int = 0,
    anchor_mode: str = 'first',
    target_fps: int = 16,
    clip_length_frames: int = 16,
    step_frames: int = 16,
    progress_callback: Optional[callable] = None,
    obj_id: Optional[int] = None
) -> list:
    """
    Process video into clips with registration: load mask, process frames, save as clips.

    Args:
        video_path: Path to input video
        mask_path: Path to HDF5 mask file
        output_dir: Directory to save clips
        box_size: Size of crop box in pixels
        target_size: Final output size
        background_mode: 'white', 'black', 'gray', 'blur', 'none'
        normalization_method: 'CLAHE', 'Histogram Equalization', 'Mean-Variance', or 'None'
        mask_feather_px: Feathering radius for mask edges
        anchor_mode: 'frame' (per-frame centroid) or 'first' (fixed at first frame)
        target_fps: Target FPS for clips (frames will be subsampled)
        clip_length_frames: Number of frames per clip
        step_frames: Step size between clips
        progress_callback: Optional function(clip_num, total_clips, obj_id) for progress updates
        obj_id: If provided, only process this object ID. If None and multiple objects exist, creates separate clips.

    Returns:
        List of tuples: (clip_path, start_frame, end_frame). When multiple
        objects are processed the results for all objects are combined
        into one flat list.
    """
    try:
        # Load mask data
        mask_data = load_segmentation_data(mask_path)
        frame_objects = mask_data['frame_objects']
        num_frames = len(frame_objects)
        # Get start_offset if masks were trimmed to a range
        start_offset = mask_data.get('start_offset', 0)

        # Collect all object ids, normalizing to int where possible so
        # mixed int/str ids dedupe and sort consistently.
        all_obj_ids = set()
        for frame_objs in frame_objects:
            for obj in frame_objs:
                oid = obj.get('obj_id', 0)
                try:
                    all_obj_ids.add(int(oid))
                except (ValueError, TypeError):
                    all_obj_ids.add(oid)
        try:
            all_obj_ids = sorted(list(all_obj_ids), key=lambda x: int(x))
        except (ValueError, TypeError):
            # Non-numeric ids present: fall back to lexicographic order.
            all_obj_ids = sorted(list(all_obj_ids), key=str)

        # Multiple objects, none requested: recurse once per object and
        # combine the per-object clip lists into one flat list.
        if obj_id is None and len(all_obj_ids) > 1:
            logger.info("Processing %d objects: %s", len(all_obj_ids), all_obj_ids)
            all_clip_paths = []
            for oid in all_obj_ids:
                logger.info("Starting processing for object %s...", oid)
                try:
                    clip_paths = process_video_to_clips(
                        video_path, mask_path, output_dir,
                        box_size, target_size, background_mode,
                        normalization_method,
                        mask_feather_px, anchor_mode,
                        target_fps, clip_length_frames, step_frames,
                        progress_callback, obj_id=oid
                    )
                    all_clip_paths.extend(clip_paths)
                    logger.info("Completed object %s: %d clips generated.", oid, len(clip_paths))
                except Exception as e:
                    logger.error("Error processing object %s: %s", oid, e, exc_info=True)
            return all_clip_paths

        if obj_id is None:
            obj_id = all_obj_ids[0] if all_obj_ids else None

        os.makedirs(output_dir, exist_ok=True)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        try:
            if start_offset > 0:
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_offset)

            orig_fps = cap.get(cv2.CAP_PROP_FPS)
            if orig_fps <= 0:
                orig_fps = 30.0

            # Subsample every Nth frame to approximate target_fps.
            frame_interval = max(1, int(round(orig_fps / target_fps)))
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_offset)

            clip_paths = []  # (clip_path, start_frame, end_frame)
            frame_idx = start_offset  # Original video frame index
            clip_idx = 0
            frames_buffer = []
            # Sampled frames still to skip when step_frames > clip_length_frames
            skip_remaining = 0
            frame_indices_in_buffer = []  # Track original frame indices for each frame in buffer
            clip_start_mask_frame_idx = None  # Track first frame (original index)

            # Persistent anchor for current clip (used when anchor_mode='first')
            clip_anchor_cx = None
            clip_anchor_cy = None

            end_frame_limit = start_offset + num_frames  # process only the trimmed range

            while True:
                if frame_idx >= end_frame_limit:
                    break

                ret, frame_bgr = cap.read()
                if not ret:
                    break

                if frame_idx % frame_interval == 0:
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        frame_idx += 1
                        continue
                    # Map original video frame index to mask array index.
                    mask_frame_idx = frame_idx - start_offset
                    if mask_frame_idx < 0 or mask_frame_idx >= len(frame_objects):
                        frame_idx += 1
                        continue

                    if len(frames_buffer) == 0:
                        # Starting a new clip: reset and (re)compute the
                        # per-clip anchor from this clip's first frame.
                        clip_start_mask_frame_idx = mask_frame_idx
                        clip_anchor_cx = None
                        clip_anchor_cy = None

                        if anchor_mode == 'first' and mask_frame_idx < len(frame_objects) and frame_objects[mask_frame_idx]:
                            obj0 = None
                            for o in frame_objects[mask_frame_idx]:
                                o_id = o.get('obj_id')
                                if str(o_id) == str(obj_id):
                                    obj0 = o
                                    break
                            if obj0:
                                # Rebuild the full-frame mask from the
                                # bbox-cropped one (clamped to extents).
                                mask0 = obj0['mask']
                                x0_min, y0_min, x0_max, y0_max = obj0['bbox']
                                h0, w0 = frame_bgr.shape[:2]
                                full_mask0 = np.zeros((h0, w0), dtype=np.uint8)
                                bbox_h0 = max(0, y0_max - y0_min + 1)
                                bbox_w0 = max(0, x0_max - x0_min + 1)
                                if bbox_h0 > 0 and bbox_w0 > 0:
                                    mh0, mw0 = mask0.shape
                                    h0_use = min(bbox_h0, mh0, h0 - max(0, y0_min))
                                    w0_use = min(bbox_w0, mw0, w0 - max(0, x0_min))
                                    if h0_use > 0 and w0_use > 0:
                                        try:
                                            full_mask0[y0_min:y0_min + h0_use, x0_min:x0_min + w0_use] = (
                                                mask0[:h0_use, :w0_use] > 0
                                            ).astype(np.uint8)
                                        except Exception:
                                            pass
                                # Anchor at centroid, else bbox center.
                                m0 = cv2.moments(full_mask0, binaryImage=True)
                                if m0['m00'] > 0:
                                    clip_anchor_cx = m0['m10'] / m0['m00']
                                    clip_anchor_cy = m0['m01'] / m0['m00']
                                else:
                                    clip_anchor_cx = (x0_min + x0_max) / 2.0
                                    clip_anchor_cy = (y0_min + y0_max) / 2.0

                    frame_processed = process_frame_with_mask(
                        frame_bgr,
                        frame_objects,
                        mask_frame_idx,
                        box_size=box_size,
                        target_size=target_size,
                        background_mode=background_mode,
                        normalization_method=normalization_method,
                        mask_feather_px=mask_feather_px,
                        anchor_cx=clip_anchor_cx,
                        anchor_cy=clip_anchor_cy,
                        anchor_mode=anchor_mode,
                        obj_id=obj_id
                    )

                    # RGB float [0,1] -> uint8 BGR for the writer.
                    frame_bgr_out = (frame_processed * 255.0).astype(np.uint8)
                    frame_bgr_out = cv2.cvtColor(frame_bgr_out, cv2.COLOR_RGB2BGR)

                    frames_buffer.append(frame_bgr_out)
                    frame_indices_in_buffer.append(frame_idx)

                    if len(frames_buffer) == clip_length_frames:
                        # Buffer full: write one clip to disk.
                        video_basename = os.path.splitext(os.path.basename(video_path))[0]
                        clip_name = f"{video_basename}_clip_{clip_idx:06d}"
                        if obj_id is not None:
                            clip_name += f"_obj{obj_id}"
                        clip_path = os.path.join(output_dir, f"{clip_name}.mp4")

                        # NOTE(review): despite the names, these hold
                        # original-video frame indices, not mask indices.
                        clip_start_mask_frame_idx = frame_indices_in_buffer[0]
                        clip_end_mask_frame_idx = frame_indices_in_buffer[-1]

                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        out = cv2.VideoWriter(clip_path, fourcc, target_fps, (target_size, target_size), isColor=True)

                        if out.isOpened():
                            for f in frames_buffer:
                                out.write(f)
                            out.release()
                            clip_paths.append((clip_path, clip_start_mask_frame_idx, clip_end_mask_frame_idx))

                        if progress_callback:
                            # Try 3-arg, then 2-arg callback signatures.
                            try:
                                progress_callback(clip_idx + 1, None, obj_id)
                            except TypeError:
                                try:
                                    progress_callback(clip_idx + 1, None)
                                except TypeError:
                                    pass

                        clip_idx += 1

                        if step_frames < clip_length_frames:
                            # Overlapping clips: keep the tail of the buffer.
                            frames_buffer = frames_buffer[clip_length_frames - step_frames:]
                            frame_indices_in_buffer = frame_indices_in_buffer[clip_length_frames - step_frames:]
                            clip_start_mask_frame_idx = frame_indices_in_buffer[0] if frame_indices_in_buffer else None
                        else:
                            # Non-overlapping: clear buffer and skip the gap.
                            frames_buffer = []
                            frame_indices_in_buffer = []
                            clip_start_mask_frame_idx = None
                            skip_remaining = max(0, step_frames - clip_length_frames)

                frame_idx += 1

            return clip_paths
        finally:
            cap.release()

    except Exception as e:
        logger.error("Error processing video to clips: %s", e, exc_info=True)
        return []
688
+