singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,688 @@
|
|
|
1
|
+
"""Video processing with mask-based background removal and centering."""
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import cv2
|
|
5
|
+
import h5py
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def load_segmentation_data(file_path: str) -> dict:
    """Load per-frame segmentation objects from an HDF5 file.

    Args:
        file_path: Path to a ``.h5``/``.hdf5`` segmentation file.

    Returns:
        Dict with video metadata (``video_path``, ``total_frames``,
        ``height``, ``width``, ``fps``, ``start_offset``,
        ``original_total_frames``) plus ``frame_objects`` — one list of
        ``{'bbox', 'mask', 'obj_id'}`` dicts per frame — and
        ``objects_per_frame`` counts.

    Raises:
        ValueError: If the file extension is not HDF5.
    """
    if not file_path.lower().endswith((".h5", ".hdf5")):
        raise ValueError(f"Unsupported segmentation format: {file_path}")

    with h5py.File(file_path, "r") as f:
        objects_root = f["frame_objects"]
        n_frames = int(f.attrs.get("total_frames", len(objects_root)))

        per_frame = []
        for idx in range(n_frames):
            key = f"frame_{idx:06d}"
            entries = []
            # Frames with no stored group simply have no objects.
            if key in objects_root:
                grp = objects_root[key]
                for name in sorted(grp.keys()):
                    node = grp[name]
                    entries.append({
                        "bbox": tuple(int(v) for v in node.attrs["bbox"]),
                        "mask": node["mask"][()].astype(bool),
                        "obj_id": int(node.attrs["obj_id"]),
                    })
            per_frame.append(entries)

        return {
            "video_path": f.attrs.get("video_path", file_path),
            "total_frames": n_frames,
            "height": int(f.attrs["height"]),
            "width": int(f.attrs["width"]),
            "fps": float(f.attrs.get("fps", 30.0)),
            "frame_objects": per_frame,
            "objects_per_frame": [int(v) for v in f["objects_per_frame"][()]],
            "tracker": {},
            "format": "hdf5",
            "start_offset": int(f.attrs.get("start_offset", 0)),
            "original_total_frames": int(f.attrs.get("original_total_frames", n_frames)),
        }
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def save_segmentation_data(file_path: str, mask_data: dict) -> None:
    """Write per-frame segmentation objects to an HDF5 file.

    Inverse of :func:`load_segmentation_data`: stores video metadata as file
    attributes, per-frame object counts as ``objects_per_frame``, and each
    object's bbox/id as attributes with its mask as a gzip-compressed dataset.

    Args:
        file_path: Destination ``.h5``/``.hdf5`` path; parent dirs are created.
        mask_data: Dict in the shape produced by :func:`load_segmentation_data`.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    frame_objects = mask_data.get("frame_objects", [])
    with h5py.File(file_path, "w") as f:
        f.attrs["video_path"] = str(mask_data.get("video_path", ""))
        f.attrs["total_frames"] = int(mask_data.get("total_frames", len(frame_objects)))
        f.attrs["height"] = int(mask_data.get("height", 0))
        f.attrs["width"] = int(mask_data.get("width", 0))
        f.attrs["fps"] = float(mask_data.get("fps", 30.0))
        f.attrs["start_offset"] = int(mask_data.get("start_offset", 0))
        f.attrs["original_total_frames"] = int(
            mask_data.get("original_total_frames", len(frame_objects))
        )

        # Default the per-frame counts to the actual object list lengths.
        objects_per_frame = np.asarray(
            mask_data.get("objects_per_frame", [len(objs) for objs in frame_objects]),
            dtype=np.int32,
        )
        f.create_dataset("objects_per_frame", data=objects_per_frame, compression="gzip")

        frames_group = f.create_group("frame_objects")
        for frame_idx, objs in enumerate(frame_objects):
            frame_group = frames_group.create_group(f"frame_{frame_idx:06d}")
            for obj_idx, obj in enumerate(objs):
                obj_group = frame_group.create_group(f"obj_{obj_idx:03d}")
                obj_group.attrs["obj_id"] = int(obj.get("obj_id", 0))
                obj_group.attrs["bbox"] = np.asarray(obj.get("bbox", (0, 0, 0, 0)), dtype=np.int32)
                # Masks are stored as uint8 (bools are not portable across HDF5 readers).
                mask = np.asarray(obj.get("mask", np.zeros((1, 1), dtype=bool)), dtype=np.uint8)
                obj_group.create_dataset(
                    "mask",
                    data=mask,
                    compression="gzip",
                    compression_opts=4,
                    shuffle=True,
                )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def convert_old_format_to_objects(data: np.ndarray) -> list:
    """Convert an old-format (frames, channels, H, W) mask array into the
    per-frame ``frame_objects`` list used by the HDF5 format.

    Each frame with a non-empty mask (channel 0) yields a single object dict
    containing the tight bounding box, the cropped boolean mask, and
    ``obj_id`` 0; empty frames yield an empty list.
    """
    n_frames, _channels, _height, _width = data.shape
    converted = []
    for idx in range(n_frames):
        plane = data[idx, 0, :, :]
        rows, cols = np.where(plane)
        if rows.size == 0 or cols.size == 0:
            converted.append([])
            continue
        top, bottom = rows.min(), rows.max()
        left, right = cols.min(), cols.max()
        converted.append([{
            'bbox': (left, top, right, bottom),
            'mask': plane[top:bottom + 1, left:right + 1].astype(bool),
            'obj_id': 0,
        }])
    return converted
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def read_video_frames(video_path: str, expected_frames: Optional[int] = None) -> Tuple[list, float]:
    """Read frames from a video file.

    Args:
        video_path: Path to the video.
        expected_frames: If given, stop after reading this many frames.

    Returns:
        Tuple of (list of BGR frames as ndarrays, fps reported by the container).

    Raises:
        ValueError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    # Release the capture even if decoding raises mid-stream.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while expected_frames is None or len(frames) < expected_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        return frames, fps
    finally:
        cap.release()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def process_frame_with_mask(
    frame_bgr: np.ndarray,
    frame_objects: list,
    frame_idx: int,
    box_size: int = 250,
    target_size: int = 288,
    background_mode: str = 'white',
    normalization_method: str = 'CLAHE',
    mask_feather_px: int = 0,
    anchor_cx: Optional[float] = None,
    anchor_cy: Optional[float] = None,
    anchor_mode: str = 'frame',
    obj_id: Optional[int] = None
) -> np.ndarray:
    """
    Process a single frame: crop around centroid and remove background.

    Args:
        frame_bgr: Full frame in BGR channel order (as read by OpenCV).
        frame_objects: Per-frame lists of object dicts ('bbox', 'mask', 'obj_id').
        frame_idx: Index into frame_objects for this frame.
        box_size: Side length in pixels of the square crop taken from the frame.
        target_size: Side length in pixels of the resized output.
        background_mode: 'black', 'gray', 'white', or 'blur' replaces the
            background; any other value leaves it untouched.
        normalization_method: 'CLAHE', 'Histogram Equalization', 'Mean-Variance', or 'None'
        mask_feather_px: Gaussian feather radius for the mask edge; 0 disables.
        anchor_cx: Precomputed crop-center x, used when anchor_mode == 'first'.
        anchor_cy: Precomputed crop-center y, used when anchor_mode == 'first'.
        anchor_mode: 'frame' (center on this frame's centroid) or 'first'
            (center on the supplied anchor).
        obj_id: If provided, only process this object ID. Otherwise, use first object.

    Returns processed frame (target_size, target_size, 3) in [0,1] float32.
    An all-zero (black) frame is returned when no matching object exists.
    """
    h, w = frame_bgr.shape[:2]

    # No objects recorded for this frame -> blank output.
    if frame_idx >= len(frame_objects) or not frame_objects[frame_idx]:
        return np.zeros((target_size, target_size, 3), dtype=np.float32)

    # Select the object: compare ids as strings so int/str ids still match.
    obj = None
    if obj_id is not None:
        for o in frame_objects[frame_idx]:
            o_id = o.get('obj_id')
            if str(o_id) == str(obj_id):
                obj = o
                break
        if obj is None:
            return np.zeros((target_size, target_size, 3), dtype=np.float32)
    else:
        obj = frame_objects[frame_idx][0]
    mask = obj['mask']
    x_min, y_min, x_max, y_max = obj['bbox']

    # Paste the bbox-local mask into a full-frame binary mask, clipping to
    # both the stored mask's size and the frame bounds.
    full_mask = np.zeros((h, w), dtype=np.uint8)
    bbox_h = max(0, y_max - y_min + 1)
    bbox_w = max(0, x_max - x_min + 1)

    if bbox_h > 0 and bbox_w > 0:
        mh, mw = mask.shape
        h_use = min(bbox_h, mh, h - max(0, y_min))
        w_use = min(bbox_w, mw, w - max(0, x_min))
        if h_use > 0 and w_use > 0:
            try:
                full_mask[y_min:y_min + h_use, x_min:x_min + w_use] = (
                    mask[:h_use, :w_use] > 0
                ).astype(np.uint8)
            except Exception:
                # Best-effort: a malformed bbox/mask just leaves full_mask empty.
                pass

    # Mask centroid via image moments; fall back to the bbox center when the
    # pasted mask is empty (m00 == 0).
    m = cv2.moments(full_mask, binaryImage=True)
    if m['m00'] > 0:
        cx = m['m10'] / m['m00']
        cy = m['m01'] / m['m00']
    else:
        cx = (x_min + x_max) / 2.0
        cy = (y_min + y_max) / 2.0

    # Fixed anchor (first-frame centroid) vs. per-frame centroid.
    if anchor_mode == 'first' and anchor_cx is not None and anchor_cy is not None:
        cx_use, cy_use = anchor_cx, anchor_cy
    else:
        cx_use, cy_use = cx, cy

    # Desired crop window centered on the anchor; may extend past the frame.
    half = box_size // 2
    desired_x1 = int(round(cx_use)) - half
    desired_y1 = int(round(cy_use)) - half
    desired_x2 = desired_x1 + box_size
    desired_y2 = desired_y1 + box_size

    # Portion of the desired window that actually lies inside the frame.
    src_x1 = max(0, desired_x1)
    src_y1 = max(0, desired_y1)
    src_x2 = min(w, desired_x2)
    src_y2 = min(h, desired_y2)

    if src_x2 <= src_x1 or src_y2 <= src_y1:
        # Window is entirely off-frame: use a blank crop and mask.
        crop_rgb = np.zeros((box_size, box_size, 3), dtype=np.uint8)
        crop_mask = np.zeros((box_size, box_size), dtype=np.uint8)
    else:
        crop_bgr = frame_bgr[src_y1:src_y2, src_x1:src_x2]
        crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
        crop_mask = full_mask[src_y1:src_y2, src_x1:src_x2]

        # Zero-pad so the crop is exactly box_size even at frame edges; the
        # pads restore the clipped-off margins of the desired window.
        pad_left = src_x1 - desired_x1
        pad_top = src_y1 - desired_y1
        pad_right = desired_x2 - src_x2
        pad_bottom = desired_y2 - src_y2

        if pad_left or pad_top or pad_right or pad_bottom:
            padded_rgb = np.zeros((box_size, box_size, 3), dtype=np.uint8)
            padded_mask = np.zeros((box_size, box_size), dtype=np.uint8)
            h_c, w_c = crop_rgb.shape[:2]
            padded_rgb[pad_top:pad_top + h_c, pad_left:pad_left + w_c] = crop_rgb
            padded_mask[pad_top:pad_top + h_c, pad_left:pad_left + w_c] = crop_mask
            crop_rgb, crop_mask = padded_rgb, padded_mask

    # Intensity normalization: every variant converts to gray, normalizes,
    # then replicates back to 3 channels (output is effectively grayscale).
    if normalization_method == 'CLAHE':
        gray = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2GRAY)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray_eq = clahe.apply(gray)
        crop_rgb = cv2.cvtColor(gray_eq, cv2.COLOR_GRAY2RGB)

    elif normalization_method == 'Histogram Equalization':
        gray = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2GRAY)
        gray_eq = cv2.equalizeHist(gray)
        crop_rgb = cv2.cvtColor(gray_eq, cv2.COLOR_GRAY2RGB)

    elif normalization_method == 'Mean-Variance':
        # Rescale to mean 127 / std 50, guarding against a constant image.
        gray = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2GRAY).astype(np.float32)
        mean, std = cv2.meanStdDev(gray)
        if std[0][0] > 0:
            gray = (gray - mean[0][0]) / std[0][0] * 50 + 127
        else:
            gray = gray - mean[0][0] + 127
        gray = np.clip(gray, 0, 255).astype(np.uint8)
        crop_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    # Safety net: force the crop (and mask) back to box_size if anything
    # above left them mis-sized.
    if crop_rgb.shape[0] != box_size or crop_rgb.shape[1] != box_size:
        pad_img = np.zeros((box_size, box_size, 3), dtype=np.uint8)
        h_c, w_c = crop_rgb.shape[:2]
        pad_img[:h_c, :w_c] = crop_rgb
        crop_rgb = pad_img

        pad_mask = np.zeros((box_size, box_size), dtype=np.uint8)
        # Defensive check; crop_mask is set on both branches above.
        if 'crop_mask' in locals():
            pad_mask[:h_c, :w_c] = crop_mask
        crop_mask = pad_mask

    # Composite the object over the chosen background, optionally feathering
    # the mask edge with a Gaussian blur for a soft transition.
    if background_mode in ('black', 'gray', 'blur', 'white'):
        m_float = crop_mask.astype(np.float32)
        if mask_feather_px and mask_feather_px > 0:
            k = max(1, int(mask_feather_px) | 1)  # kernel size must be odd
            m_float = cv2.GaussianBlur(m_float, (k, k), 0)
            m_float = np.clip(m_float, 0.0, 1.0)
        m3 = np.repeat(m_float[:, :, None], 3, axis=2)

        if background_mode == 'black':
            bg = np.zeros_like(crop_rgb)
        elif background_mode == 'gray':
            bg = np.full_like(crop_rgb, 128)
        elif background_mode == 'white':
            bg = np.full_like(crop_rgb, 255)
        else:  # blur
            bg = cv2.GaussianBlur(crop_rgb, (11, 11), 0)

        # Alpha blend: object where mask=1, background where mask=0.
        crop_rgb = (m3 * crop_rgb + (1.0 - m3) * bg).astype(np.uint8)

    # Resize to target size
    frame_rgb = cv2.resize(crop_rgb, (target_size, target_size), interpolation=cv2.INTER_AREA)

    # Normalize to [0, 1]
    return (frame_rgb.astype(np.float32) / 255.0)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def process_video(
    video_path: str,
    mask_path: str,
    output_path: str,
    box_size: int = 250,
    target_size: int = 288,
    background_mode: str = 'white',
    mask_feather_px: int = 0,
    anchor_mode: str = 'first',
    progress_callback: Optional[callable] = None,
    obj_id: Optional[int] = None
) -> bool:
    """
    Process entire video: load mask, process each frame, save output.

    Args:
        video_path: Path to input video
        mask_path: Path to HDF5 mask file
        output_path: Path to save processed video (or base path if obj_id is None and multiple objects exist)
        box_size: Size of crop box in pixels
        target_size: Final output size
        background_mode: 'white', 'black', 'gray', 'blur', 'none'
        mask_feather_px: Feathering radius for mask edges
        anchor_mode: 'frame' (per-frame centroid) or 'first' (fixed at first frame)
        progress_callback: Optional function(frame_num, total_frames, obj_id) for progress updates
        obj_id: If provided, only process this object ID. If None and multiple objects exist, creates separate videos.

    Returns:
        True if successful, False otherwise. If multiple objects, returns list of output paths.
    """
    try:
        mask_data = load_segmentation_data(mask_path)
        frame_objects = mask_data['frame_objects']
        num_frames = len(frame_objects)
        # NOTE(review): the mask file's start_offset is not applied here —
        # frames are read from the beginning of the video. Confirm alignment
        # for masks saved with start_offset > 0.

        # Collect every object id present in the mask file. A distinct loop
        # variable is used so the caller's obj_id argument is not clobbered
        # (the original code overwrote obj_id here, which broke the
        # multi-object and default-object selection below).
        all_obj_ids = set()
        for frame_objs in frame_objects:
            for obj in frame_objs:
                all_obj_ids.add(obj.get('obj_id', 0))
        all_obj_ids = sorted(all_obj_ids)

        if obj_id is None and len(all_obj_ids) > 1:
            # Multiple objects, no specific id requested: recurse once per
            # object, writing each to "<base>_obj<id><ext>".
            output_paths = []
            base_path = os.path.splitext(output_path)[0]
            ext = os.path.splitext(output_path)[1]

            for oid in all_obj_ids:
                obj_output_path = f"{base_path}_obj{oid}{ext}"
                success = process_video(
                    video_path, mask_path, obj_output_path,
                    box_size, target_size, background_mode,
                    mask_feather_px, anchor_mode, progress_callback, obj_id=oid
                )
                if success:
                    output_paths.append(obj_output_path)

            return output_paths if output_paths else False

        if obj_id is None:
            obj_id = all_obj_ids[0] if all_obj_ids else None

        raw_frames, fps = read_video_frames(video_path, expected_frames=num_frames)
        if len(raw_frames) != num_frames:
            logger.warning("Video has %d frames, mask has %d", len(raw_frames), num_frames)
            num_frames = min(len(raw_frames), num_frames)

        # For anchor_mode='first', find the first frame (within the first 10)
        # that contains the target object and fix the crop center at its
        # mask centroid (bbox center if the mask is empty).
        anchor_cx = None
        anchor_cy = None
        if anchor_mode == 'first':
            for frame_idx in range(min(10, num_frames)):  # Check first 10 frames
                if frame_idx < len(frame_objects) and frame_objects[frame_idx]:
                    obj0 = None
                    for o in frame_objects[frame_idx]:
                        # String comparison, consistent with process_frame_with_mask.
                        if str(o.get('obj_id')) == str(obj_id):
                            obj0 = o
                            break
                    if obj0 is None:
                        continue

                    mask0 = obj0['mask']
                    x0_min, y0_min, x0_max, y0_max = obj0['bbox']
                    frame_bgr0 = raw_frames[frame_idx]
                    h0, w0 = frame_bgr0.shape[:2]
                    full_mask0 = np.zeros((h0, w0), dtype=np.uint8)
                    bbox_h0 = max(0, y0_max - y0_min + 1)
                    bbox_w0 = max(0, x0_max - x0_min + 1)
                    if bbox_h0 > 0 and bbox_w0 > 0:
                        mh0, mw0 = mask0.shape
                        h0_use = min(bbox_h0, mh0, h0 - max(0, y0_min))
                        w0_use = min(bbox_w0, mw0, w0 - max(0, x0_min))
                        if h0_use > 0 and w0_use > 0:
                            try:
                                full_mask0[y0_min:y0_min + h0_use, x0_min:x0_min + w0_use] = (
                                    mask0[:h0_use, :w0_use] > 0
                                ).astype(np.uint8)
                            except Exception:
                                pass
                    m0 = cv2.moments(full_mask0, binaryImage=True)
                    if m0['m00'] > 0:
                        anchor_cx = m0['m10'] / m0['m00']
                        anchor_cy = m0['m01'] / m0['m00']
                        break
                    else:
                        anchor_cx = (x0_min + x0_max) / 2.0
                        anchor_cy = (y0_min + y0_max) / 2.0
                        break

        # Use mp4v codec for efficient encoding
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (target_size, target_size), isColor=True)

        if not out.isOpened():
            raise ValueError(f"Failed to open video writer for {output_path}")

        try:
            for frame_idx in range(num_frames):
                if progress_callback and callable(progress_callback):
                    # Prefer the 3-arg form; fall back to (frame, total).
                    try:
                        progress_callback(frame_idx + 1, num_frames, obj_id)
                    except TypeError:
                        progress_callback(frame_idx + 1, num_frames)

                frame_bgr = raw_frames[frame_idx]

                frame_processed = process_frame_with_mask(
                    frame_bgr,
                    frame_objects,
                    frame_idx,
                    box_size=box_size,
                    target_size=target_size,
                    background_mode=background_mode,
                    mask_feather_px=mask_feather_px,
                    anchor_cx=anchor_cx,
                    anchor_cy=anchor_cy,
                    anchor_mode=anchor_mode,
                    obj_id=obj_id
                )

                # Back from float [0,1] RGB to uint8 BGR for the writer.
                frame_bgr_out = (frame_processed * 255.0).astype(np.uint8)
                frame_bgr_out = cv2.cvtColor(frame_bgr_out, cv2.COLOR_RGB2BGR)
                out.write(frame_bgr_out)

            return True
        finally:
            out.release()

    except Exception as e:
        logger.error("Error processing video: %s", e, exc_info=True)
        return False
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def process_video_to_clips(
    video_path: str,
    mask_path: str,
    output_dir: str,
    box_size: int = 250,
    target_size: int = 288,
    background_mode: str = 'white',
    normalization_method: str = 'CLAHE',
    mask_feather_px: int = 0,
    anchor_mode: str = 'first',
    target_fps: int = 16,
    clip_length_frames: int = 16,
    step_frames: int = 16,
    progress_callback: Optional[callable] = None,
    obj_id: Optional[int] = None
) -> list:
    """
    Process video into clips with registration: load mask, process frames, save as clips.

    Args:
        video_path: Path to input video
        mask_path: Path to HDF5 mask file
        output_dir: Directory to save clips
        box_size: Size of crop box in pixels
        target_size: Final output size
        background_mode: 'white', 'black', 'gray', 'blur', 'none'
        normalization_method: 'CLAHE', 'Histogram Equalization', 'Mean-Variance', or 'None'
        mask_feather_px: Feathering radius for mask edges
        anchor_mode: 'frame' (per-frame centroid) or 'first' (fixed at first frame)
        target_fps: Target FPS for clips (frames will be subsampled)
        clip_length_frames: Number of frames per clip
        step_frames: Step size between clips
        progress_callback: Optional function(clip_num, total_clips, obj_id) for progress updates
        obj_id: If provided, only process this object ID. If None and multiple objects exist, creates separate clips.

    Returns:
        List of tuples: (clip_path, start_frame, end_frame). If multiple objects, returns list of lists.
    """
    try:
        # Load mask data
        mask_data = load_segmentation_data(mask_path)
        frame_objects = mask_data['frame_objects']
        num_frames = len(frame_objects)
        # Get start_offset if masks were trimmed to a range
        start_offset = mask_data.get('start_offset', 0)

        # Collect every object id, normalizing to int where possible so that
        # mixed int/str ids dedupe and sort sensibly.
        all_obj_ids = set()
        for frame_objs in frame_objects:
            for obj in frame_objs:
                oid = obj.get('obj_id', 0)
                try:
                    all_obj_ids.add(int(oid))
                except (ValueError, TypeError):
                    all_obj_ids.add(oid)
        try:
            all_obj_ids = sorted(all_obj_ids, key=lambda x: int(x))
        except (ValueError, TypeError):
            # Non-numeric ids remain: fall back to lexicographic order.
            all_obj_ids = sorted(all_obj_ids, key=str)

        if obj_id is None and len(all_obj_ids) > 1:
            # Recurse once per object; each object gets its own clips.
            logger.info("Processing %d objects: %s", len(all_obj_ids), all_obj_ids)
            all_clip_paths = []
            for oid in all_obj_ids:
                logger.info("Starting processing for object %s...", oid)
                try:
                    clip_paths = process_video_to_clips(
                        video_path, mask_path, output_dir,
                        box_size, target_size, background_mode,
                        normalization_method,
                        mask_feather_px, anchor_mode,
                        target_fps, clip_length_frames, step_frames,
                        progress_callback, obj_id=oid
                    )
                    all_clip_paths.extend(clip_paths)
                    logger.info("Completed object %s: %d clips generated.", oid, len(clip_paths))
                except Exception as e:
                    logger.error("Error processing object %s: %s", oid, e, exc_info=True)
            return all_clip_paths

        if obj_id is None:
            obj_id = all_obj_ids[0] if all_obj_ids else None

        os.makedirs(output_dir, exist_ok=True)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        try:
            if start_offset > 0:
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_offset)

            orig_fps = cap.get(cv2.CAP_PROP_FPS)
            if orig_fps <= 0:
                orig_fps = 30.0  # sane default when the container reports nothing

            # Keep every frame_interval-th frame to approximate target_fps.
            frame_interval = max(1, int(round(orig_fps / target_fps)))
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_offset)

            clip_paths = []  # (clip_path, start_frame, end_frame)
            frame_idx = start_offset  # Original video frame index
            clip_idx = 0
            frames_buffer = []
            skip_remaining = 0
            frame_indices_in_buffer = []  # Track original frame indices for each frame in buffer
            clip_start_mask_frame_idx = None  # Track first frame (original index)

            # Persistent anchor for current clip (used when anchor_mode='first')
            clip_anchor_cx = None
            clip_anchor_cy = None

            end_frame_limit = start_offset + num_frames  # process only the trimmed range

            while True:
                if frame_idx >= end_frame_limit:
                    break

                ret, frame_bgr = cap.read()
                if not ret:
                    break

                # Only sampled frames (aligned to absolute indices) are processed.
                if frame_idx % frame_interval == 0:
                    if skip_remaining > 0:
                        # Inter-clip gap when step_frames > clip_length_frames.
                        skip_remaining -= 1
                        frame_idx += 1
                        continue
                    mask_frame_idx = frame_idx - start_offset
                    if mask_frame_idx < 0 or mask_frame_idx >= len(frame_objects):
                        frame_idx += 1
                        continue

                    if len(frames_buffer) == 0:
                        # Starting a new clip: reset and, for anchor_mode='first',
                        # compute a fixed anchor from this frame's mask centroid.
                        clip_start_mask_frame_idx = mask_frame_idx
                        clip_anchor_cx = None
                        clip_anchor_cy = None

                        if anchor_mode == 'first' and mask_frame_idx < len(frame_objects) and frame_objects[mask_frame_idx]:
                            obj0 = None
                            for o in frame_objects[mask_frame_idx]:
                                o_id = o.get('obj_id')
                                if str(o_id) == str(obj_id):
                                    obj0 = o
                                    break
                            if obj0:
                                mask0 = obj0['mask']
                                x0_min, y0_min, x0_max, y0_max = obj0['bbox']
                                h0, w0 = frame_bgr.shape[:2]
                                full_mask0 = np.zeros((h0, w0), dtype=np.uint8)
                                bbox_h0 = max(0, y0_max - y0_min + 1)
                                bbox_w0 = max(0, x0_max - x0_min + 1)
                                if bbox_h0 > 0 and bbox_w0 > 0:
                                    mh0, mw0 = mask0.shape
                                    h0_use = min(bbox_h0, mh0, h0 - max(0, y0_min))
                                    w0_use = min(bbox_w0, mw0, w0 - max(0, x0_min))
                                    if h0_use > 0 and w0_use > 0:
                                        try:
                                            full_mask0[y0_min:y0_min + h0_use, x0_min:x0_min + w0_use] = (
                                                mask0[:h0_use, :w0_use] > 0
                                            ).astype(np.uint8)
                                        except Exception:
                                            pass
                                m0 = cv2.moments(full_mask0, binaryImage=True)
                                if m0['m00'] > 0:
                                    clip_anchor_cx = m0['m10'] / m0['m00']
                                    clip_anchor_cy = m0['m01'] / m0['m00']
                                else:
                                    clip_anchor_cx = (x0_min + x0_max) / 2.0
                                    clip_anchor_cy = (y0_min + y0_max) / 2.0

                    frame_processed = process_frame_with_mask(
                        frame_bgr,
                        frame_objects,
                        mask_frame_idx,
                        box_size=box_size,
                        target_size=target_size,
                        background_mode=background_mode,
                        normalization_method=normalization_method,
                        mask_feather_px=mask_feather_px,
                        anchor_cx=clip_anchor_cx,
                        anchor_cy=clip_anchor_cy,
                        anchor_mode=anchor_mode,
                        obj_id=obj_id
                    )

                    # float [0,1] RGB -> uint8 BGR for the writer.
                    frame_bgr_out = (frame_processed * 255.0).astype(np.uint8)
                    frame_bgr_out = cv2.cvtColor(frame_bgr_out, cv2.COLOR_RGB2BGR)

                    frames_buffer.append(frame_bgr_out)
                    frame_indices_in_buffer.append(frame_idx)

                    if len(frames_buffer) == clip_length_frames:
                        # Buffer full: flush it to a clip file.
                        video_basename = os.path.splitext(os.path.basename(video_path))[0]
                        clip_name = f"{video_basename}_clip_{clip_idx:06d}"
                        if obj_id is not None:
                            clip_name += f"_obj{obj_id}"
                        clip_path = os.path.join(output_dir, f"{clip_name}.mp4")

                        clip_start_mask_frame_idx = frame_indices_in_buffer[0]
                        clip_end_mask_frame_idx = frame_indices_in_buffer[-1]

                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        out = cv2.VideoWriter(clip_path, fourcc, target_fps, (target_size, target_size), isColor=True)

                        if out.isOpened():
                            for f in frames_buffer:
                                out.write(f)
                            out.release()
                            clip_paths.append((clip_path, clip_start_mask_frame_idx, clip_end_mask_frame_idx))

                        if progress_callback:
                            # Tolerate (clip, total, obj) and (clip, total) signatures.
                            try:
                                progress_callback(clip_idx + 1, None, obj_id)
                            except TypeError:
                                try:
                                    progress_callback(clip_idx + 1, None)
                                except TypeError:
                                    pass

                        clip_idx += 1

                        # Sliding window: keep the overlap; otherwise clear the
                        # buffer and schedule the inter-clip skip gap.
                        if step_frames < clip_length_frames:
                            frames_buffer = frames_buffer[clip_length_frames - step_frames:]
                            frame_indices_in_buffer = frame_indices_in_buffer[clip_length_frames - step_frames:]
                            clip_start_mask_frame_idx = frame_indices_in_buffer[0] if frame_indices_in_buffer else None
                        else:
                            frames_buffer = []
                            frame_indices_in_buffer = []
                            clip_start_mask_frame_idx = None
                            skip_remaining = max(0, step_frames - clip_length_frames)

                frame_idx += 1

            return clip_paths
        finally:
            cap.release()

    except Exception as e:
        logger.error("Error processing video to clips: %s", e, exc_info=True)
        return []
|
|
688
|
+
|