singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
# @package _global_
|
|
2
|
+
|
|
3
|
+
scratch:
|
|
4
|
+
resolution: 1024
|
|
5
|
+
train_batch_size: 1
|
|
6
|
+
num_train_workers: 10
|
|
7
|
+
num_frames: 8
|
|
8
|
+
max_num_objects: 3
|
|
9
|
+
base_lr: 5.0e-6
|
|
10
|
+
vision_lr: 3.0e-06
|
|
11
|
+
phases_per_epoch: 1
|
|
12
|
+
num_epochs: 40
|
|
13
|
+
|
|
14
|
+
dataset:
|
|
15
|
+
# PATHS to Dataset
|
|
16
|
+
img_folder: null # PATH to MOSE JPEGImages folder
|
|
17
|
+
gt_folder: null # PATH to MOSE Annotations folder
|
|
18
|
+
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
|
|
19
|
+
multiplier: 2
|
|
20
|
+
|
|
21
|
+
# Video transforms
|
|
22
|
+
vos:
|
|
23
|
+
train_transforms:
|
|
24
|
+
- _target_: training.dataset.transforms.ComposeAPI
|
|
25
|
+
transforms:
|
|
26
|
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
|
27
|
+
consistent_transform: True
|
|
28
|
+
- _target_: training.dataset.transforms.RandomAffine
|
|
29
|
+
degrees: 25
|
|
30
|
+
shear: 20
|
|
31
|
+
image_interpolation: bilinear
|
|
32
|
+
consistent_transform: True
|
|
33
|
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
|
34
|
+
sizes: ${scratch.resolution}
|
|
35
|
+
square: true
|
|
36
|
+
consistent_transform: True
|
|
37
|
+
- _target_: training.dataset.transforms.ColorJitter
|
|
38
|
+
consistent_transform: True
|
|
39
|
+
brightness: 0.1
|
|
40
|
+
contrast: 0.03
|
|
41
|
+
saturation: 0.03
|
|
42
|
+
hue: null
|
|
43
|
+
- _target_: training.dataset.transforms.RandomGrayscale
|
|
44
|
+
p: 0.05
|
|
45
|
+
consistent_transform: True
|
|
46
|
+
- _target_: training.dataset.transforms.ColorJitter
|
|
47
|
+
consistent_transform: False
|
|
48
|
+
brightness: 0.1
|
|
49
|
+
contrast: 0.05
|
|
50
|
+
saturation: 0.05
|
|
51
|
+
hue: null
|
|
52
|
+
- _target_: training.dataset.transforms.ToTensorAPI
|
|
53
|
+
- _target_: training.dataset.transforms.NormalizeAPI
|
|
54
|
+
mean: [0.485, 0.456, 0.406]
|
|
55
|
+
std: [0.229, 0.224, 0.225]
|
|
56
|
+
|
|
57
|
+
trainer:
|
|
58
|
+
_target_: training.trainer.Trainer
|
|
59
|
+
mode: train_only
|
|
60
|
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
|
61
|
+
accelerator: cuda
|
|
62
|
+
seed_value: 123
|
|
63
|
+
|
|
64
|
+
model:
|
|
65
|
+
_target_: training.model.sam2.SAM2Train
|
|
66
|
+
image_encoder:
|
|
67
|
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
|
68
|
+
scalp: 1
|
|
69
|
+
trunk:
|
|
70
|
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
|
71
|
+
embed_dim: 112
|
|
72
|
+
num_heads: 2
|
|
73
|
+
drop_path_rate: 0.1
|
|
74
|
+
neck:
|
|
75
|
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
|
76
|
+
position_encoding:
|
|
77
|
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
|
78
|
+
num_pos_feats: 256
|
|
79
|
+
normalize: true
|
|
80
|
+
scale: null
|
|
81
|
+
temperature: 10000
|
|
82
|
+
d_model: 256
|
|
83
|
+
backbone_channel_list: [896, 448, 224, 112]
|
|
84
|
+
fpn_top_down_levels: [2, 3] # outputs at levels 0 and 1 directly use the backbone features
|
|
85
|
+
fpn_interp_model: nearest
|
|
86
|
+
|
|
87
|
+
memory_attention:
|
|
88
|
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
|
89
|
+
d_model: 256
|
|
90
|
+
pos_enc_at_input: true
|
|
91
|
+
layer:
|
|
92
|
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
|
93
|
+
activation: relu
|
|
94
|
+
dim_feedforward: 2048
|
|
95
|
+
dropout: 0.1
|
|
96
|
+
pos_enc_at_attn: false
|
|
97
|
+
self_attention:
|
|
98
|
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
|
99
|
+
rope_theta: 10000.0
|
|
100
|
+
feat_sizes: [64, 64]
|
|
101
|
+
embedding_dim: 256
|
|
102
|
+
num_heads: 1
|
|
103
|
+
downsample_rate: 1
|
|
104
|
+
dropout: 0.1
|
|
105
|
+
d_model: 256
|
|
106
|
+
pos_enc_at_cross_attn_keys: true
|
|
107
|
+
pos_enc_at_cross_attn_queries: false
|
|
108
|
+
cross_attention:
|
|
109
|
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
|
110
|
+
rope_theta: 10000.0
|
|
111
|
+
feat_sizes: [64, 64]
|
|
112
|
+
rope_k_repeat: True
|
|
113
|
+
embedding_dim: 256
|
|
114
|
+
num_heads: 1
|
|
115
|
+
downsample_rate: 1
|
|
116
|
+
dropout: 0.1
|
|
117
|
+
kv_in_dim: 64
|
|
118
|
+
num_layers: 4
|
|
119
|
+
|
|
120
|
+
memory_encoder:
|
|
121
|
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
|
122
|
+
out_dim: 64
|
|
123
|
+
position_encoding:
|
|
124
|
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
|
125
|
+
num_pos_feats: 64
|
|
126
|
+
normalize: true
|
|
127
|
+
scale: null
|
|
128
|
+
temperature: 10000
|
|
129
|
+
mask_downsampler:
|
|
130
|
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
|
131
|
+
kernel_size: 3
|
|
132
|
+
stride: 2
|
|
133
|
+
padding: 1
|
|
134
|
+
fuser:
|
|
135
|
+
_target_: sam2.modeling.memory_encoder.Fuser
|
|
136
|
+
layer:
|
|
137
|
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
|
138
|
+
dim: 256
|
|
139
|
+
kernel_size: 7
|
|
140
|
+
padding: 3
|
|
141
|
+
layer_scale_init_value: 1e-6
|
|
142
|
+
use_dwconv: True # depth-wise convs
|
|
143
|
+
num_layers: 2
|
|
144
|
+
|
|
145
|
+
num_maskmem: 7
|
|
146
|
+
image_size: ${scratch.resolution}
|
|
147
|
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
|
148
|
+
sigmoid_scale_for_mem_enc: 20.0
|
|
149
|
+
sigmoid_bias_for_mem_enc: -10.0
|
|
150
|
+
use_mask_input_as_output_without_sam: true
|
|
151
|
+
# Memory
|
|
152
|
+
directly_add_no_mem_embed: true
|
|
153
|
+
no_obj_embed_spatial: true
|
|
154
|
+
# use high-resolution feature map in the SAM mask decoder
|
|
155
|
+
use_high_res_features_in_sam: true
|
|
156
|
+
# output 3 masks on the first click on initial conditioning frames
|
|
157
|
+
multimask_output_in_sam: true
|
|
158
|
+
# SAM heads
|
|
159
|
+
iou_prediction_use_sigmoid: True
|
|
160
|
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
|
161
|
+
use_obj_ptrs_in_encoder: true
|
|
162
|
+
add_tpos_enc_to_obj_ptrs: true
|
|
163
|
+
proj_tpos_enc_in_obj_ptrs: true
|
|
164
|
+
use_signed_tpos_enc_to_obj_ptrs: true
|
|
165
|
+
only_obj_ptrs_in_the_past_for_eval: true
|
|
166
|
+
# object occlusion prediction
|
|
167
|
+
pred_obj_scores: true
|
|
168
|
+
pred_obj_scores_mlp: true
|
|
169
|
+
fixed_no_obj_ptr: true
|
|
170
|
+
# multimask tracking settings
|
|
171
|
+
multimask_output_for_tracking: true
|
|
172
|
+
use_multimask_token_for_obj_ptr: true
|
|
173
|
+
multimask_min_pt_num: 0
|
|
174
|
+
multimask_max_pt_num: 1
|
|
175
|
+
use_mlp_for_obj_ptr_proj: true
|
|
176
|
+
# Compilation flag
|
|
177
|
+
# compile_image_encoder: False
|
|
178
|
+
|
|
179
|
+
####### Training specific params #######
|
|
180
|
+
# box/point input and corrections
|
|
181
|
+
prob_to_use_pt_input_for_train: 0.5
|
|
182
|
+
prob_to_use_pt_input_for_eval: 0.0
|
|
183
|
+
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
|
|
184
|
+
prob_to_use_box_input_for_eval: 0.0
|
|
185
|
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
|
186
|
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
|
187
|
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
|
188
|
+
rand_frames_to_correct_for_train: True # randomly sample the number of frames to correct, from #init-cond-frames up to 2
|
|
189
|
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
|
190
|
+
# maximum 2 initial conditioning frames
|
|
191
|
+
num_init_cond_frames_for_train: 2
|
|
192
|
+
rand_init_cond_frames_for_train: True # random 1~2
|
|
193
|
+
num_correction_pt_per_frame: 7
|
|
194
|
+
use_act_ckpt_iterative_pt_sampling: false
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
|
199
|
+
forward_backbone_per_frame_for_eval: True
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
data:
|
|
203
|
+
train:
|
|
204
|
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
|
205
|
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
|
206
|
+
batch_sizes:
|
|
207
|
+
- ${scratch.train_batch_size}
|
|
208
|
+
|
|
209
|
+
datasets:
|
|
210
|
+
- _target_: training.dataset.utils.RepeatFactorWrapper
|
|
211
|
+
dataset:
|
|
212
|
+
_target_: training.dataset.utils.ConcatDataset
|
|
213
|
+
datasets:
|
|
214
|
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
|
215
|
+
transforms: ${vos.train_transforms}
|
|
216
|
+
training: true
|
|
217
|
+
video_dataset:
|
|
218
|
+
_target_: training.dataset.vos_raw_dataset.PNGRawDataset
|
|
219
|
+
img_folder: ${dataset.img_folder}
|
|
220
|
+
gt_folder: ${dataset.gt_folder}
|
|
221
|
+
file_list_txt: ${dataset.file_list_txt}
|
|
222
|
+
sampler:
|
|
223
|
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
|
224
|
+
num_frames: ${scratch.num_frames}
|
|
225
|
+
max_num_objects: ${scratch.max_num_objects}
|
|
226
|
+
multiplier: ${dataset.multiplier}
|
|
227
|
+
shuffle: True
|
|
228
|
+
num_workers: ${scratch.num_train_workers}
|
|
229
|
+
pin_memory: True
|
|
230
|
+
drop_last: True
|
|
231
|
+
collate_fn:
|
|
232
|
+
_target_: training.utils.data_utils.collate_fn
|
|
233
|
+
_partial_: true
|
|
234
|
+
dict_key: all
|
|
235
|
+
|
|
236
|
+
optim:
|
|
237
|
+
amp:
|
|
238
|
+
enabled: True
|
|
239
|
+
amp_dtype: bfloat16
|
|
240
|
+
|
|
241
|
+
optimizer:
|
|
242
|
+
_target_: torch.optim.AdamW
|
|
243
|
+
|
|
244
|
+
gradient_clip:
|
|
245
|
+
_target_: training.optimizer.GradientClipper
|
|
246
|
+
max_norm: 0.1
|
|
247
|
+
norm_type: 2
|
|
248
|
+
|
|
249
|
+
param_group_modifiers:
|
|
250
|
+
- _target_: training.optimizer.layer_decay_param_modifier
|
|
251
|
+
_partial_: True
|
|
252
|
+
layer_decay_value: 0.9
|
|
253
|
+
apply_to: 'image_encoder.trunk'
|
|
254
|
+
overrides:
|
|
255
|
+
- pattern: '*pos_embed*'
|
|
256
|
+
value: 1.0
|
|
257
|
+
|
|
258
|
+
options:
|
|
259
|
+
lr:
|
|
260
|
+
- scheduler:
|
|
261
|
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
|
262
|
+
start_value: ${scratch.base_lr}
|
|
263
|
+
end_value: ${divide:${scratch.base_lr},10}
|
|
264
|
+
- scheduler:
|
|
265
|
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
|
266
|
+
start_value: ${scratch.vision_lr}
|
|
267
|
+
end_value: ${divide:${scratch.vision_lr},10}
|
|
268
|
+
param_names:
|
|
269
|
+
- 'image_encoder.*'
|
|
270
|
+
weight_decay:
|
|
271
|
+
- scheduler:
|
|
272
|
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
|
273
|
+
value: 0.1
|
|
274
|
+
- scheduler:
|
|
275
|
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
|
276
|
+
value: 0.0
|
|
277
|
+
param_names:
|
|
278
|
+
- '*bias*'
|
|
279
|
+
module_cls_names: ['torch.nn.LayerNorm']
|
|
280
|
+
|
|
281
|
+
loss:
|
|
282
|
+
all:
|
|
283
|
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
|
284
|
+
weight_dict:
|
|
285
|
+
loss_mask: 20
|
|
286
|
+
loss_dice: 1
|
|
287
|
+
loss_iou: 1
|
|
288
|
+
loss_class: 1
|
|
289
|
+
supervise_all_iou: true
|
|
290
|
+
iou_use_l1_loss: true
|
|
291
|
+
pred_obj_scores: true
|
|
292
|
+
focal_gamma_obj_score: 0.0
|
|
293
|
+
focal_alpha_obj_score: -1.0
|
|
294
|
+
|
|
295
|
+
distributed:
|
|
296
|
+
backend: nccl
|
|
297
|
+
find_unused_parameters: True
|
|
298
|
+
|
|
299
|
+
logging:
|
|
300
|
+
tensorboard_writer:
|
|
301
|
+
_target_: training.utils.logger.make_tensorboard_logger
|
|
302
|
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
|
303
|
+
flush_secs: 120
|
|
304
|
+
should_log: True
|
|
305
|
+
log_dir: ${launcher.experiment_log_dir}/logs
|
|
306
|
+
log_freq: 10
|
|
307
|
+
|
|
308
|
+
# initialize from a SAM 2 checkpoint
|
|
309
|
+
checkpoint:
|
|
310
|
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
|
311
|
+
save_freq: 0 # 0: only the last checkpoint is saved.
|
|
312
|
+
model_weight_initializer:
|
|
313
|
+
_partial_: True
|
|
314
|
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
|
315
|
+
strict: True
|
|
316
|
+
ignore_unexpected_keys: null
|
|
317
|
+
ignore_missing_keys: null
|
|
318
|
+
|
|
319
|
+
state_dict:
|
|
320
|
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
|
321
|
+
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
|
322
|
+
ckpt_state_dict_keys: ['model']
|
|
323
|
+
|
|
324
|
+
launcher:
|
|
325
|
+
num_nodes: 1
|
|
326
|
+
gpus_per_node: 8
|
|
327
|
+
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
|
328
|
+
|
|
329
|
+
# SLURM args if running on a cluster
|
|
330
|
+
submitit:
|
|
331
|
+
partition: null
|
|
332
|
+
account: null
|
|
333
|
+
qos: null
|
|
334
|
+
cpus_per_task: 10
|
|
335
|
+
use_cluster: false
|
|
336
|
+
timeout_hour: 24
|
|
337
|
+
name: null
|
|
338
|
+
port_range: [10000, 65000]
|
|
339
|
+
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import List, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
import torch.nn.functional as F
|
|
14
|
+
from iopath.common.file_io import g_pathmgr
|
|
15
|
+
|
|
16
|
+
from sam2.modeling.backbones.utils import (
|
|
17
|
+
PatchEmbed,
|
|
18
|
+
window_partition,
|
|
19
|
+
window_unpartition,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from sam2.modeling.sam2_utils import DropPath, MLP
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    """Apply a channels-first pooling module to a channels-last tensor.

    Args:
        x: Input laid out as ``(B, H, W, C)``.
        pool: Module that consumes ``(B, C, H, W)`` input (e.g. ``nn.MaxPool2d``).
            When ``None``, ``x`` is returned untouched and ``norm`` is ignored.
        norm: Optional module applied to the pooled, channels-last result.

    Returns:
        The pooled tensor, laid out as ``(B, H', W', C)``.
    """
    if pool is None:
        return x

    # Move channels in front for the pooling op, then back to channels-last.
    pooled = pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    return norm(pooled) if norm else pooled
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MultiScaleAttention(nn.Module):
    """Multi-head self-attention over a channels-last map, with optional query pooling.

    The input is ``(B, H, W, dim)``. If ``q_pool`` is given, only the queries
    are spatially pooled before attention. The output therefore takes the
    pooled spatial size, while keys and values keep the full input resolution.

    Args:
        dim: Input channel count.
        dim_out: Output channel count (also the total width of q, k and v).
        num_heads: Number of attention heads.
        q_pool: Optional channels-first pooling module applied to the queries.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        # One projection yields queries, keys and values stacked together.
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, height, width, _ = x.shape
        heads = self.num_heads

        # (B, H*W, 3, heads, C_head) -> three tensors of shape (B, H*W, heads, C_head)
        stacked = self.qkv(x).reshape(batch, height * width, 3, heads, -1)
        q, k, v = stacked.unbind(2)

        # Downsample the queries only (used at stage transitions).
        if self.q_pool:
            pooled = do_pool(q.reshape(batch, height, width, -1), self.q_pool)
            height, width = pooled.shape[1], pooled.shape[2]
            q = pooled.reshape(batch, height * width, heads, -1)

        # SDPA expects (B, heads, tokens, C_head).
        attended = F.scaled_dot_product_attention(
            *(t.transpose(1, 2) for t in (q, k, v))
        )
        merged = attended.transpose(1, 2).reshape(batch, height, width, -1)
        return self.proj(merged)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MultiScaleBlock(nn.Module):
    """Transformer block with (optionally windowed) attention followed by an MLP.

    The block works on channels-last feature maps ``(B, H, W, C)``. When
    ``q_stride`` is given, the attention queries and the projected skip
    connection are max-pooled, so the block reduces the spatial resolution.

    Args:
        dim: Input channel count.
        dim_out: Output channel count.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim_out``.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        norm_layer: Normalization class, or the name of a ``torch.nn`` class
            (instantiated with ``eps=1e-6``).
        q_stride: Kernel size / stride of the query max-pool; ``None`` disables pooling.
        act_layer: Activation used inside the MLP.
        window_size: Side length of the attention windows; ``0`` means global
            attention over the whole feature map.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        # A projection is only needed when the channel count changes.
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection: project (and pool) so it matches the attention output.
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)

        # Reverse window partition. The padding bookkeeping is only needed here.
        # Computing it for global-attention blocks (window_size == 0) that also
        # pool queries would divide by zero.
        if self.window_size > 0:
            if self.q_stride:
                # Shapes have changed due to Q pooling
                window_size = self.window_size // self.q_stride[0]
                H, W = shortcut.shape[1:3]

                pad_h = (window_size - H % window_size) % window_size
                pad_w = (window_size - W % window_size) % window_size
                pad_hw = (H + pad_h, W + pad_w)

            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class Hiera(nn.Module):
    """
    Hierarchical vision transformer (Hiera) backbone.

    Reference: https://arxiv.org/abs/2306.00989

    The input is patch-embedded and a windowed positional embedding is added
    (https://arxiv.org/abs/2311.05613). The tokens then pass through
    ``sum(stages)`` ``MultiScaleBlock``s. At every stage change, the channel
    width and head count are multiplied by ``dim_mul`` / ``head_mul``. The first
    block of each of the first ``q_pool`` later stages also downsamples by
    ``q_stride`` via query pooling.

    Args:
        embed_dim: Channel width of the first stage.
        num_heads: Number of attention heads in the first stage.
        drop_path_rate: Maximum stochastic-depth rate (linearly ramped over blocks).
        q_pool: Number of stage transitions that downsample.
        q_stride: Pooling stride used at those transitions.
        stages: Number of blocks per stage.
        dim_mul: Channel multiplier at each stage change.
        head_mul: Head-count multiplier at each stage change.
        window_pos_embed_bkg_spatial_size: Spatial size of the background
            positional embedding (interpolated to the feature size).
        window_spec: Attention window size per stage.
        global_att_blocks: Indices of blocks that use global attention instead.
        weights_path: Optional checkpoint to load (non-strictly) at construction.
        return_interm_layers: If True, ``forward`` returns the output of every
            stage; otherwise only the last one.
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of each stage.
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # Query pooling happens in the first block of the first `q_pool` later stages.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        # Channel counts of the returned features, listed from the last stage to the first.
        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            # NOTE(review): depending on the torch version, torch.load may
            # unpickle arbitrary objects; only load trusted checkpoints.
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            load_result = self.load_state_dict(chkpt, strict=False)
            # The message needs a placeholder for the argument. Without one,
            # logging reports a formatting error instead of the missing/unexpected keys.
            logging.info("loading Hiera: %s", load_result)

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        """Build the channels-last positional embedding for an ``(h, w)`` feature map.

        The background embedding is bicubically resized to ``(h, w)``. The
        per-window embedding is then tiled over it. ``h`` and ``w`` are expected
        to be multiples of ``window_spec[0]``; otherwise the tiled tensor will not
        match the resized one.
        """
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run the backbone.

        Returns:
            Channels-first ``(B, C, H, W)`` feature maps, one per stage end in
            order from the first stage to the last. With
            ``return_interm_layers=False``, only the last stage's map is returned.
        """
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        """Map a parameter name to a layer index for layer-wise LR decay.

        Positional and patch embeddings map to 0, ``blocks.<i>`` maps to
        ``i + 1``, and everything else (including ``rel_pos``) maps to
        ``num_layers + 1``.
        """
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        """Return the number of transformer blocks."""
        return len(self.blocks)
|