singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
sam2/utils/transforms.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from torchvision.transforms import Normalize, Resize, ToTensor
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SAM2Transforms(nn.Module):
    """Image / prompt / mask transforms used by the SAM 2 predictors.

    * Images are converted with ``ToTensor``, resized to a square
      ``resolution x resolution`` input (aspect ratio is not preserved) and
      normalized with the fixed per-channel ``mean`` / ``std`` set below.
    * Point and box prompts are mapped into that ``resolution``-sized frame.
    * Predicted mask logits can optionally be cleaned up (small holes filled,
      small isolated blobs removed) and are resized to the original image size.
    """

    def __init__(
        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
    ):
        """
        Transforms for SAM2.

        Args:
            resolution: side length (pixels) of the square model input.
            mask_threshold: mask-logit threshold; scores <= threshold are treated
                as background and scores > threshold as foreground.
            max_hole_area: background components with area <= this value are
                filled in ``postprocess_masks``; 0 disables hole filling.
            max_sprinkle_area: foreground components with area <= this value are
                removed in ``postprocess_masks``; 0 disables sprinkle removal.
        """
        super().__init__()
        self.resolution = resolution
        self.mask_threshold = mask_threshold
        self.max_hole_area = max_hole_area
        self.max_sprinkle_area = max_sprinkle_area
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.to_tensor = ToTensor()
        self.transforms = torch.jit.script(
            nn.Sequential(
                Resize((self.resolution, self.resolution)),
                Normalize(self.mean, self.std),
            )
        )

    def __call__(self, x):
        """Convert one image with ``ToTensor`` and apply resize + normalization."""
        x = self.to_tensor(x)
        return self.transforms(x)

    def forward_batch(self, img_list):
        """Transform every image in ``img_list`` and stack them along a new dim 0."""
        img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
        img_batch = torch.stack(img_batch, dim=0)
        return img_batch

    def transform_coords(
        self, coords: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Map point coordinates into the model's ``resolution``-sized input frame.

        Expects a tensor whose last dimension has length 2, holding (x, y).
        The coordinates can be absolute image coordinates or already normalized:

        * ``normalize=True``: ``coords`` are absolute pixel coordinates in an image
          of size ``orig_hw = (h, w)``; x is divided by ``w`` and y by ``h``.
        * ``normalize=False``: ``coords`` are taken to be normalized to [0, 1].

        Returns:
            A new tensor of coordinates scaled to ``[0, resolution]``, which is what
            the SAM2 model expects. The input tensor is not modified.

        Raises:
            ValueError: if ``normalize`` is True and ``orig_hw`` is not provided.
        """
        if normalize:
            if orig_hw is None:
                raise ValueError("orig_hw is required when normalize=True")
            h, w = orig_hw
            coords = coords.clone()  # don't mutate the caller's tensor in place
            coords[..., 0] = coords[..., 0] / w
            coords[..., 1] = coords[..., 1] / h

        coords = coords * self.resolution  # unnormalize coords
        return coords

    def transform_boxes(
        self, boxes: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Map boxes into the model's ``resolution``-sized input frame.

        Expects a tensor of shape Bx4. Each box is reshaped into two (x, y) corner
        points, so the result has shape Bx2x2. The coordinates can be absolute
        image coordinates or normalized to [0, 1]. For absolute coordinates, set
        ``normalize=True`` and pass ``orig_hw = (h, w)`` (see ``transform_coords``).
        """
        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
        return boxes

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Post-process output mask logits and resize them to ``orig_hw``.

        When enabled:

        * background connected components (scores <= ``mask_threshold``) with area
          <= ``max_hole_area`` are filled by setting them to
          ``mask_threshold + 10.0``;
        * foreground components (scores > ``mask_threshold``) with area
          <= ``max_sprinkle_area`` are removed by setting them to
          ``mask_threshold - 10.0``.

        Both passes compute components from the input masks. If anything in this
        cleanup step raises (e.g. the connected-components CUDA kernel is
        unavailable), the cleanup is skipped with a warning and the unmodified
        masks are used. Finally, masks are bilinearly resized to ``orig_hw``.

        Args:
            masks: mask logits. They are flattened over their first two dims for
                connected components and passed to ``F.interpolate`` with a 2-D
                size, so a 4-D tensor is expected (presumably B x C x H x W).
            orig_hw: target (height, width).
        """
        from sam2.utils.misc import get_connected_components

        masks = masks.float()
        input_masks = masks
        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
        try:
            if self.max_hole_area > 0:
                # Holes are connected components in the background with area
                # <= self.max_hole_area (background = scores <= self.mask_threshold).
                labels, areas = get_connected_components(
                    mask_flat <= self.mask_threshold
                )
                is_hole = (labels > 0) & (areas <= self.max_hole_area)
                is_hole = is_hole.reshape_as(masks)
                # Fill holes with a small positive mask score (10.0) to change them to foreground.
                masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)

            if self.max_sprinkle_area > 0:
                # Sprinkles are small foreground components (scores > self.mask_threshold).
                labels, areas = get_connected_components(
                    mask_flat > self.mask_threshold
                )
                is_sprinkle = (labels > 0) & (areas <= self.max_sprinkle_area)
                is_sprinkle = is_sprinkle.reshape_as(masks)
                # Fill sprinkles with a negative mask score (-10.0) to change them to background.
                masks = torch.where(is_sprinkle, self.mask_threshold - 10.0, masks)
        except Exception as e:
            # Skip the post-processing step if the CUDA kernel fails
            warnings.warn(
                f"{e}\n\nSkipping the post-processing step due to the error above. You can "
                "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
                "functionality may be limited (which doesn't affect the results in most cases; see "
                "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
                category=UserWarning,
                stacklevel=2,
            )
            masks = input_masks

        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Main entry point for SingleBehaviorLab.
|
|
4
|
+
Runs when invoked as:
|
|
5
|
+
python -m singlebehaviorlab
|
|
6
|
+
singlebehaviorlab (pip entry point)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import sys
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
os.environ.setdefault("GRPC_VERBOSITY", "ERROR")
os.environ.setdefault("GLOG_minloglevel", "2")

# Let JAX grow GPU memory on demand and leave headroom for PyTorch.
# These two are overwritten (not setdefault), so they apply even if the
# environment already sets them.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.45"

# Fall back to driver JIT compilation when ptxas/nvlink is unavailable.
# XLA_FLAGS is a space-separated list of flags: append ours instead of
# replacing flags the user already set, and respect an explicit user setting
# of the same flag (e.g. "...=false").
_PTXAS_FALLBACK_FLAG = "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
_xla_flags = os.environ.get("XLA_FLAGS", "").split()
if not any(flag.split("=", 1)[0] == _PTXAS_FALLBACK_FLAG for flag in _xla_flags):
    _xla_flags.append(_PTXAS_FALLBACK_FLAG)
os.environ["XLA_FLAGS"] = " ".join(_xla_flags)
|
|
23
|
+
|
|
24
|
+
import yaml
|
|
25
|
+
from singlebehaviorlab._paths import get_default_config_path, get_experiments_dir
|
|
26
|
+
from singlebehaviorlab.gui.main_window import MainWindow
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def load_config(config_path: str = None) -> dict:
    """Load configuration from YAML file.

    Args:
        config_path: YAML file to read. When ``None``, the path returned by
            ``get_default_config_path()`` is used. A missing file yields an
            empty starting config.

    Returns:
        The config dict, with:

        * the standard path keys (``data_dir``, ``raw_videos_dir``, ``clips_dir``,
          ``annotations_dir``, ``models_dir``, ``backbone_dir``,
          ``annotation_file``) defaulted when missing/blank, ``~``-expanded, and
          resolved against ``base_dir`` when relative;
        * ``experiments_dir`` defaulted to ``get_experiments_dir()`` when
          missing/blank, ``~``-expanded, and created on disk;
        * ``config_path`` set to the path that was used, and
          ``experiment_name`` / ``experiment_path`` defaulting to ``None``.

    Raises:
        ValueError: if the YAML document is not a mapping.
    """
    if config_path is None:
        config_path = str(get_default_config_path())

    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f) or {}
    else:
        config = {}

    # A YAML list/scalar would otherwise fail later with an opaque AttributeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a mapping at the top level, "
            f"got {type(config).__name__}"
        )

    # Path keys that are missing or blank get defaults under base_dir; relative
    # values are resolved against base_dir. base_dir is the parent of the
    # singlebehaviorlab package directory (the repository root for a source
    # checkout).
    # NOTE(review): for a pip install this parent is site-packages — confirm
    # that is the intended location for these defaults.
    from singlebehaviorlab._paths import get_package_dir
    base_dir = str(get_package_dir().parent)

    defaults = {
        "data_dir": os.path.join(base_dir, "data"),
        "raw_videos_dir": os.path.join(base_dir, "data", "raw_videos"),
        "clips_dir": os.path.join(base_dir, "data", "clips"),
        "annotations_dir": os.path.join(base_dir, "data", "annotations"),
        "models_dir": os.path.join(base_dir, "models", "behavior_heads"),
        "backbone_dir": os.path.join(base_dir, "models", "videoprism_backbone"),
        "annotation_file": os.path.join(base_dir, "data", "annotations", "annotations.json"),
    }

    for key, default in defaults.items():
        value = config.get(key)
        if not value:
            config[key] = default
            continue
        # Expand "~" first; otherwise "~/data" would be joined under base_dir.
        value = os.path.expanduser(value)
        if not os.path.isabs(value):
            value = os.path.join(base_dir, value)
        config[key] = value

    experiments_dir = str(get_experiments_dir())
    if not config.get("experiments_dir"):
        config["experiments_dir"] = experiments_dir
    else:
        # Expand "~" so makedirs does not create a literal "~" directory.
        config["experiments_dir"] = os.path.expanduser(config["experiments_dir"])
    os.makedirs(config["experiments_dir"], exist_ok=True)

    config["config_path"] = config_path
    config.setdefault("experiment_name", None)
    config.setdefault("experiment_path", None)

    return config
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def main():
    """Application entry point.

    Configures logging, creates the QApplication and loads the configuration.
    It then shows a modal startup dialog asking whether to create a new
    experiment or load an existing one, opens the main window, acts on the
    choice, and runs the Qt event loop. Exits the process via ``sys.exit``
    with the event loop's return code.
    """
    logging.basicConfig(
        level=logging.WARNING,
        format="%(levelname)s [%(name)s] %(message)s",
    )
    from PyQt6.QtWidgets import (
        QApplication, QDialog, QVBoxLayout, QPushButton, QLabel,
    )

    app = QApplication(sys.argv)
    app.setApplicationName("SingleBehaviorLab")

    config = load_config()

    # "Create" accepts the dialog (Accepted == 1). "Load" finishes it with this
    # dedicated result code. Dismissing the dialog any other way (window close
    # button, Escape) yields Rejected (0), which must not be treated as "Load".
    load_result = 2

    startup_dialog = QDialog()
    startup_dialog.setWindowTitle("Welcome - Experiment Management")
    startup_dialog.setMinimumSize(400, 200)
    startup_dialog.setModal(True)

    layout = QVBoxLayout()

    welcome_label = QLabel(
        "<h2>Welcome to SingleBehaviorLab</h2>"
        "<p>Please choose an option to get started:</p>"
    )
    welcome_label.setWordWrap(True)
    layout.addWidget(welcome_label)

    def add_choice_button(text):
        """Create a large bold push button and append it to the dialog layout."""
        button = QPushButton(text)
        button.setMinimumHeight(40)
        button.setStyleSheet("font-size: 12px; font-weight: bold;")
        layout.addWidget(button)
        return button

    create_btn = add_choice_button("Create New Experiment")
    create_btn.clicked.connect(startup_dialog.accept)

    load_btn = add_choice_button("Load Existing Experiment")
    load_btn.clicked.connect(lambda _checked=False: startup_dialog.done(load_result))

    startup_dialog.setLayout(layout)

    result = startup_dialog.exec()

    window = MainWindow(config)
    window.show()

    if result == QDialog.DialogCode.Accepted:
        window._create_experiment()
    elif result == load_result:
        window._load_experiment()
    # Any other result (dialog dismissed): just show the main window without
    # creating or loading an experiment.

    sys.exit(app.exec())
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Start the GUI only when this module is executed directly
# (e.g. ``python -m singlebehaviorlab``), not when it is imported.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Central path resolver for SingleBehaviorLab.
|
|
3
|
+
|
|
4
|
+
Handles two install modes transparently:
|
|
5
|
+
|
|
6
|
+
Source / zip distribution
|
|
7
|
+
SingleBehaviorLab/
|
|
8
|
+
singlebehaviorlab/ ← this file lives here
|
|
9
|
+
sam2_backend/
|
|
10
|
+
sam2_checkpoints/
|
|
11
|
+
experiments/
|
|
12
|
+
|
|
13
|
+
pip install (site-packages)
|
|
14
|
+
site-packages/singlebehaviorlab/ ← this file lives here
|
|
15
|
+
~/SingleBehaviorLab/ ← user data lives here
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
# The singlebehaviorlab package directory, i.e. the folder containing this module.
_PKG_DIR = Path(__file__).parents[0]

# The directory one level above the package: the SingleBehaviorLab root for a
# source/zip checkout, or site-packages for a pip install.
_PKG_PARENT = _PKG_DIR.parent

# Per-user data root (~/SingleBehaviorLab), used as the fallback location
# whenever the source-layout directories are not present.
USER_DATA_DIR = Path.home().joinpath("SingleBehaviorLab")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _first_existing(*candidates: Path) -> Path:
    """Pick the first of ``candidates`` that exists on disk.

    When none exist, ``candidates[0]`` is returned even though it is missing.
    Calling with no candidates raises ``IndexError``.
    """
    return next((path for path in candidates if path.exists()), candidates[0])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Public helpers.
|
|
40
|
+
|
|
41
|
+
def get_package_dir() -> Path:
    """Return the installed singlebehaviorlab package directory.

    This is the module-level ``_PKG_DIR`` (the directory containing this
    file), computed once when the module is imported.
    """
    return _PKG_DIR
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_sam2_backend_dir() -> Path:
    """Locate the directory that contains the ``sam2`` Python package.

    Preference order:

    1. If ``sam2`` is importable and its package directory has a ``configs``
       subdirectory, return the parent of that package directory.
    2. Otherwise return the first existing of ``<pkg parent>/sam2_backend``,
       ``<pkg dir>/sam2_backend`` and ``~/SingleBehaviorLab/sam2_backend``,
       or the first of these if none exist.

    Any error while importing or inspecting ``sam2`` is ignored (best effort).
    """
    try:
        import sam2  # type: ignore

        sam2_root = Path(sam2.__file__).resolve().parent
        ships_configs = (sam2_root / "configs").exists()
    except Exception:
        ships_configs = False

    if ships_configs:
        return sam2_root.parent

    fallbacks = (
        _PKG_PARENT / "sam2_backend",
        _PKG_DIR / "sam2_backend",
        USER_DATA_DIR / "sam2_backend",
    )
    return _first_existing(*fallbacks)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_sam2_checkpoints_dir() -> Path:
    """Return the directory holding SAM2 checkpoint files.

    Uses ``<pkg parent>/sam2_checkpoints`` when it already exists (source/zip
    layout). Otherwise uses ``~/SingleBehaviorLab/sam2_checkpoints``, creating
    it (and any missing parents) if needed.
    """
    checkpoints = _PKG_PARENT / "sam2_checkpoints"
    if not checkpoints.exists():
        checkpoints = USER_DATA_DIR / "sam2_checkpoints"
        checkpoints.mkdir(parents=True, exist_ok=True)
    return checkpoints
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_training_profiles_path() -> Path:
    """Locate training_profiles.json.

    Prefers the copy shipped as package data (``<pkg dir>/data``) and falls
    back to ``~/SingleBehaviorLab/training_profiles.json``. If neither exists,
    the package-data path is returned.
    """
    packaged = _PKG_DIR.joinpath("data", "training_profiles.json")
    user_copy = USER_DATA_DIR.joinpath("training_profiles.json")
    return _first_existing(packaged, user_copy)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_default_config_path() -> Path:
    """Locate the default (template) config.yaml.

    Checks ``<pkg parent>/config/config.yaml`` (source / zip layout) before the
    copy bundled as package data at ``<pkg dir>/data/config/config.yaml``.
    Returns the former if neither exists.
    """
    source_layout = _PKG_PARENT.joinpath("config", "config.yaml")
    package_data = _PKG_DIR.joinpath("data", "config", "config.yaml")
    return _first_existing(source_layout, package_data)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_experiments_dir() -> Path:
    """
    Return the default experiments root directory.

    - Source / zip install: ``SingleBehaviorLab/experiments/`` when it exists.
    - Otherwise (e.g. pip install): ``~/SingleBehaviorLab/experiments/``,
      created (with parents) if missing.
    """
    experiments = _PKG_PARENT / "experiments"
    if not experiments.exists():
        experiments = USER_DATA_DIR / "experiments"
        experiments.mkdir(parents=True, exist_ok=True)
    return experiments
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torchvision.transforms.functional as F
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ClipAugment(nn.Module):
    """Applies identical augmentation to every frame; no cropping to avoid cutting off tracked subjects.

    All random parameters are drawn once per clip by ``_sample_params``, and the
    same values are applied to every frame, so the augmented clip stays
    temporally consistent. ``augment_with_params`` also returns those values so
    callers can apply matching transforms to labels.
    """

    def __init__(
        self,
        use_horizontal_flip: bool = True,
        use_vertical_flip: bool = False,
        use_color_jitter: bool = True,
        use_gaussian_blur: bool = False,
        use_random_noise: bool = False,
        use_small_rotation: bool = False,
        use_speed_perturb: bool = False,
        use_random_shapes: bool = False,
        use_grayscale: bool = False,
        use_lighting_robustness: bool = False,
        color_jitter_brightness: float = 0.2,
        color_jitter_contrast: float = 0.2,
        color_jitter_saturation: float = 0.2,
        color_jitter_hue: float = 0.1,
        gaussian_blur_sigma: tuple[float, float] = (0.1, 0.5),
        noise_std: float = 0.01,
        rotation_degrees: float = 2.0,
        speed_range: tuple[float, float] = (0.7, 1.3),
        random_shapes_max: int = 3,
        random_shapes_max_size: float = 0.15,
        grayscale_prob: float = 0.5,
        gamma_range: tuple[float, float] = (0.75, 1.35),
        channel_gain_range: tuple[float, float] = (0.85, 1.15),
    ):
        """
        Args:
            use_*: enable/disable the individual augmentations.
            color_jitter_brightness / color_jitter_contrast / color_jitter_saturation:
                jitter strength ``s``; factors are drawn from ``[max(0, 1 - s), 1 + s]``.
            color_jitter_hue: maximum absolute hue shift (effectively capped at 0.5).
            gaussian_blur_sigma: (min, max) Gaussian blur sigma.
            noise_std: standard deviation of additive Gaussian noise.
            rotation_degrees: maximum absolute rotation angle in degrees.
            speed_range: (min, max) temporal speed factor.
            random_shapes_max: maximum number of occluding shapes (at least 1).
            random_shapes_max_size: maximum shape width/height as a fraction of the
                frame width/height.
            grayscale_prob: probability of converting a clip to grayscale.
            gamma_range: (min, max) gamma exponent for lighting robustness.
            channel_gain_range: (min, max) per-channel gain for lighting robustness.
        """
        super().__init__()
        self.use_horizontal_flip = use_horizontal_flip
        self.use_vertical_flip = use_vertical_flip
        self.use_color_jitter = use_color_jitter
        self.use_gaussian_blur = use_gaussian_blur
        self.use_random_noise = use_random_noise
        self.use_small_rotation = use_small_rotation
        self.use_speed_perturb = use_speed_perturb
        self.use_random_shapes = use_random_shapes
        self.use_grayscale = use_grayscale
        self.use_lighting_robustness = use_lighting_robustness

        self.color_jitter_brightness = color_jitter_brightness
        self.color_jitter_contrast = color_jitter_contrast
        self.color_jitter_saturation = color_jitter_saturation
        self.color_jitter_hue = color_jitter_hue
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.noise_std = noise_std
        self.rotation_degrees = rotation_degrees
        self.speed_range = speed_range
        # At least one shape is drawn whenever random shapes are enabled.
        self.random_shapes_max = max(1, random_shapes_max)
        self.random_shapes_max_size = random_shapes_max_size
        self.grayscale_prob = grayscale_prob
        self.gamma_range = gamma_range
        self.channel_gain_range = channel_gain_range

    @staticmethod
    def _sample_jitter_factor(strength: float) -> float:
        """Sample a multiplicative color-jitter factor around 1.0.

        The factor is drawn from ``[max(0, 1 - strength), 1 + strength]``, the same
        range torchvision's ``ColorJitter`` uses. Returns 1.0 when ``strength`` is
        not positive. The lower bound is clamped because ``adjust_brightness`` /
        ``adjust_contrast`` / ``adjust_saturation`` reject negative factors.
        """
        if strength > 0:
            return torch.empty(1).uniform_(max(0.0, 1 - strength), 1 + strength).item()
        return 1.0

    def _sample_params(self) -> dict:
        """Sample one set of augmentation params shared by all frames.

        Disabled augmentations get ``None`` (or ``False`` for the boolean
        flags), which ``_apply_with_params`` treats as "skip".
        """
        hflip = self.use_horizontal_flip and (torch.rand(1).item() < 0.5)
        vflip = self.use_vertical_flip and (torch.rand(1).item() < 0.5)

        if self.use_color_jitter:
            brightness_factor = self._sample_jitter_factor(self.color_jitter_brightness)
            contrast_factor = self._sample_jitter_factor(self.color_jitter_contrast)
            saturation_factor = self._sample_jitter_factor(self.color_jitter_saturation)
            # adjust_hue only accepts shifts within [-0.5, 0.5].
            hue_max = min(self.color_jitter_hue, 0.5)
            hue_factor = (
                torch.empty(1).uniform_(-hue_max, hue_max).item()
                if hue_max > 0
                else 0.0
            )
        else:
            brightness_factor = contrast_factor = saturation_factor = hue_factor = None

        if self.use_gaussian_blur:
            blur_sigma = torch.empty(1).uniform_(
                self.gaussian_blur_sigma[0],
                self.gaussian_blur_sigma[1]
            ).item()
        else:
            blur_sigma = None

        if self.use_small_rotation:
            rotation_angle = torch.empty(1).uniform_(
                -self.rotation_degrees,
                self.rotation_degrees
            ).item()
        else:
            rotation_angle = None

        if self.use_random_noise:
            noise_std = self.noise_std
        else:
            noise_std = None

        gamma_factor = None
        channel_gains = None
        if self.use_lighting_robustness:
            gamma_factor = torch.empty(1).uniform_(
                self.gamma_range[0],
                self.gamma_range[1]
            ).item()
            channel_gains = torch.empty(3).uniform_(
                self.channel_gain_range[0],
                self.channel_gain_range[1]
            ).tolist()

        # Speed perturbation: sample a speed factor once per clip
        speed_factor = None
        if self.use_speed_perturb:
            speed_factor = torch.empty(1).uniform_(
                self.speed_range[0], self.speed_range[1]
            ).item()

        # Random shapes: sample positions, sizes, colors, types once per clip
        shapes = None
        if self.use_random_shapes:
            n_shapes = torch.randint(1, self.random_shapes_max + 1, (1,)).item()
            shapes = []
            for _ in range(n_shapes):
                shape_type = torch.randint(0, 3, (1,)).item()  # 0=rect, 1=ellipse, 2=triangle
                cx = torch.rand(1).item()
                cy = torch.rand(1).item()
                sw = torch.empty(1).uniform_(0.03, self.random_shapes_max_size).item()
                sh = torch.empty(1).uniform_(0.03, self.random_shapes_max_size).item()
                color = torch.rand(3).tolist()
                shapes.append({
                    "type": shape_type, "cx": cx, "cy": cy,
                    "sw": sw, "sh": sh, "color": color,
                })

        do_grayscale = False
        if self.use_grayscale:
            do_grayscale = torch.rand(1).item() < self.grayscale_prob

        return {
            "hflip": bool(hflip),
            "vflip": bool(vflip),
            "brightness_factor": brightness_factor,
            "contrast_factor": contrast_factor,
            "saturation_factor": saturation_factor,
            "hue_factor": hue_factor,
            "blur_sigma": blur_sigma,
            "rotation_angle": rotation_angle,
            "noise_std": noise_std,
            "gamma_factor": gamma_factor,
            "channel_gains": channel_gains,
            "speed_factor": speed_factor,
            "shapes": shapes,
            "grayscale": do_grayscale,
        }

    @staticmethod
    def _resample_temporal(clip: torch.Tensor, speed_factor: float) -> torch.Tensor:
        """Resample clip frames to simulate speed change. Output has same T.

        Output frame ``i`` is source frame ``round(i * speed_factor)`` (computed
        via ``linspace`` and rounded to the nearest frame), clamped to
        ``[0, T - 1]``. The clip is returned unchanged when ``T <= 1`` or the
        factor is within 0.01 of 1.
        """
        T = clip.shape[0]
        if T <= 1 or abs(speed_factor - 1.0) < 0.01:
            return clip
        # speed_factor > 1: frames are traversed faster; once the source runs
        # out, indices clamp to T-1, so the last frame repeats at the end.
        # speed_factor < 1: only the first ~speed_factor fraction of the clip
        # is used, stretched over T output frames (indices start at 0).
        src_indices = torch.linspace(0, (T - 1) * speed_factor, T)
        src_indices = src_indices.clamp(0, T - 1)
        idx = src_indices.round().long()
        return clip[idx]

    @staticmethod
    def _draw_shapes_on_frame(frame: torch.Tensor, shapes: list, H: int, W: int) -> torch.Tensor:
        """Draw pre-sampled shapes onto a single frame [C, H, W].

        Shape centers/sizes are fractions of W/H. At most the first 3 channels
        are painted. Returns a new tensor; the input frame is not modified.
        """
        frame = frame.clone()
        for s in shapes:
            cx_px = int(s["cx"] * W)
            cy_px = int(s["cy"] * H)
            half_w = max(1, int(s["sw"] * W / 2))
            half_h = max(1, int(s["sh"] * H / 2))
            color = s["color"]

            # Bounding box of the shape, clipped to the frame.
            x1 = max(0, cx_px - half_w)
            x2 = min(W, cx_px + half_w)
            y1 = max(0, cy_px - half_h)
            y2 = min(H, cy_px + half_h)
            if x2 <= x1 or y2 <= y1:
                continue

            stype = s["type"]
            if stype == 0:
                for c_i in range(min(3, frame.shape[0])):
                    frame[c_i, y1:y2, x1:x2] = color[c_i]
            elif stype == 1:
                yy = torch.arange(y1, y2, device=frame.device).float()
                xx = torch.arange(x1, x2, device=frame.device).float()
                gy, gx = torch.meshgrid(yy, xx, indexing="ij")
                ey = (gy - cy_px) / max(half_h, 1)
                ex = (gx - cx_px) / max(half_w, 1)
                mask = (ex ** 2 + ey ** 2) <= 1.0
                for c_i in range(min(3, frame.shape[0])):
                    # region is a view into frame, so masked assignment paints frame.
                    region = frame[c_i, y1:y2, x1:x2]
                    region[mask] = color[c_i]
            elif stype == 2:
                # Triangle pointing up: apex at top center, base at bottom
                yy = torch.arange(y1, y2, device=frame.device).float()
                xx = torch.arange(x1, x2, device=frame.device).float()
                gy, gx = torch.meshgrid(yy, xx, indexing="ij")
                ny = (gy - y1) / max(y2 - y1 - 1, 1)
                nx = (gx - x1) / max(x2 - x1 - 1, 1)
                mask = (nx >= 0.5 - 0.5 * ny) & (nx <= 0.5 + 0.5 * ny)
                for c_i in range(min(3, frame.shape[0])):
                    region = frame[c_i, y1:y2, x1:x2]
                    region[mask] = color[c_i]
        return frame

    def _apply_with_params(self, clip: torch.Tensor, params: dict) -> torch.Tensor:
        """Apply a sampled parameter dict to every frame of ``clip`` [T, C, H, W].

        ``None`` / ``False`` entries skip the corresponding augmentation. The
        legacy key ``"flip"`` is accepted as an alias for ``"hflip"``.
        """
        T, C, H, W = clip.shape

        # Speed perturbation (resamples frames, applied before per-frame ops)
        speed_factor = params.get("speed_factor", None)
        if speed_factor is not None:
            clip = self._resample_temporal(clip, speed_factor)

        hflip = bool(params.get("hflip", params.get("flip", False)))
        vflip = bool(params.get("vflip", False))
        brightness_factor = params.get("brightness_factor", None)
        contrast_factor = params.get("contrast_factor", None)
        saturation_factor = params.get("saturation_factor", None)
        hue_factor = params.get("hue_factor", None)
        blur_sigma = params.get("blur_sigma", None)
        rotation_angle = params.get("rotation_angle", None)
        noise_std = params.get("noise_std", None)
        gamma_factor = params.get("gamma_factor", None)
        channel_gains = params.get("channel_gains", None)
        shapes = params.get("shapes", None)
        do_grayscale = params.get("grayscale", False)

        augmented_frames = []
        for t in range(T):
            frame = clip[t]

            if hflip:
                frame = F.hflip(frame)
            if vflip:
                frame = F.vflip(frame)

            if do_grayscale:
                frame = F.rgb_to_grayscale(frame, num_output_channels=C)

            if brightness_factor is not None:
                frame = F.adjust_brightness(frame, brightness_factor)
                frame = F.adjust_contrast(frame, contrast_factor)
                frame = F.adjust_saturation(frame, saturation_factor)
                # A zero hue shift is a no-op (up to float rounding); skip the
                # comparatively costly RGB<->HSV round trip.
                if hue_factor:
                    frame = F.adjust_hue(frame, hue_factor)

            # Gamma correction + per-channel gain shifts for lighting robustness
            if gamma_factor is not None:
                frame = torch.clamp(frame, 0.0, 1.0)
                # Small epsilon keeps pow() well-behaved at exactly 0.
                frame = torch.pow(frame + 1e-6, gamma_factor)
                if channel_gains is not None and C >= 3:
                    gains = torch.tensor(channel_gains[:3], device=frame.device, dtype=frame.dtype).view(3, 1, 1)
                    frame[:3] = frame[:3] * gains
                frame = torch.clamp(frame, 0.0, 1.0)

            if rotation_angle is not None and abs(rotation_angle) > 0.01:
                frame = F.rotate(frame, rotation_angle, interpolation=F.InterpolationMode.BILINEAR, fill=0.0)

            if blur_sigma is not None and blur_sigma > 0.01:
                # Odd kernel covering roughly +/- 4 sigma.
                kernel_size = int(2 * int(4 * blur_sigma + 0.5) + 1)
                if kernel_size >= 3:
                    frame = F.gaussian_blur(frame, kernel_size=[kernel_size, kernel_size], sigma=[blur_sigma, blur_sigma])

            # Random shapes (same position/color for every frame in the clip)
            if shapes:
                frame = self._draw_shapes_on_frame(frame, shapes, H, W)

            if noise_std is not None and noise_std > 0:
                noise = torch.randn_like(frame) * noise_std
                frame = torch.clamp(frame + noise, 0.0, 1.0)

            augmented_frames.append(frame)

        return torch.stack(augmented_frames)

    def augment_with_params(self, clip: torch.Tensor):
        """Return augmented clip and parameter dict for synchronizing spatial label transforms (e.g., bboxes, masks)."""
        params = self._sample_params()
        return self._apply_with_params(clip, params), params

    def forward(self, clip: torch.Tensor) -> torch.Tensor:
        """clip: [T, C, H, W] float tensor in [0, 1]."""
        augmented_clip, _ = self.augment_with_params(clip)
        return augmented_clip

    def __repr__(self):
        parts = [
            f"hflip={self.use_horizontal_flip}",
            f"vflip={self.use_vertical_flip}",
            f"color_jitter={self.use_color_jitter}",
            f"gaussian_blur={self.use_gaussian_blur}",
            f"random_noise={self.use_random_noise}",
            f"small_rotation={self.use_small_rotation}",
            f"speed_perturb={self.use_speed_perturb}",
            f"random_shapes={self.use_random_shapes}",
            f"grayscale={self.use_grayscale}",
            f"lighting_robustness={self.use_lighting_robustness}",
        ]
        return f"ClipAugment({', '.join(parts)})"
|