singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,339 @@
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 1024
5
+ train_batch_size: 1
6
+ num_train_workers: 10
7
+ num_frames: 8
8
+ max_num_objects: 3
9
+ base_lr: 5.0e-6
10
+ vision_lr: 3.0e-06
11
+ phases_per_epoch: 1
12
+ num_epochs: 40
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ img_folder: null # PATH to MOSE JPEGImages folder
17
+ gt_folder: null # PATH to MOSE Annotations folder
18
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
19
+ multiplier: 2
20
+
21
+ # Video transforms
22
+ vos:
23
+ train_transforms:
24
+ - _target_: training.dataset.transforms.ComposeAPI
25
+ transforms:
26
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
27
+ consistent_transform: True
28
+ - _target_: training.dataset.transforms.RandomAffine
29
+ degrees: 25
30
+ shear: 20
31
+ image_interpolation: bilinear
32
+ consistent_transform: True
33
+ - _target_: training.dataset.transforms.RandomResizeAPI
34
+ sizes: ${scratch.resolution}
35
+ square: true
36
+ consistent_transform: True
37
+ - _target_: training.dataset.transforms.ColorJitter
38
+ consistent_transform: True
39
+ brightness: 0.1
40
+ contrast: 0.03
41
+ saturation: 0.03
42
+ hue: null
43
+ - _target_: training.dataset.transforms.RandomGrayscale
44
+ p: 0.05
45
+ consistent_transform: True
46
+ - _target_: training.dataset.transforms.ColorJitter
47
+ consistent_transform: False
48
+ brightness: 0.1
49
+ contrast: 0.05
50
+ saturation: 0.05
51
+ hue: null
52
+ - _target_: training.dataset.transforms.ToTensorAPI
53
+ - _target_: training.dataset.transforms.NormalizeAPI
54
+ mean: [0.485, 0.456, 0.406]
55
+ std: [0.229, 0.224, 0.225]
56
+
57
+ trainer:
58
+ _target_: training.trainer.Trainer
59
+ mode: train_only
60
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
61
+ accelerator: cuda
62
+ seed_value: 123
63
+
64
+ model:
65
+ _target_: training.model.sam2.SAM2Train
66
+ image_encoder:
67
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
68
+ scalp: 1
69
+ trunk:
70
+ _target_: sam2.modeling.backbones.hieradet.Hiera
71
+ embed_dim: 112
72
+ num_heads: 2
73
+ drop_path_rate: 0.1
74
+ neck:
75
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
76
+ position_encoding:
77
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
78
+ num_pos_feats: 256
79
+ normalize: true
80
+ scale: null
81
+ temperature: 10000
82
+ d_model: 256
83
+ backbone_channel_list: [896, 448, 224, 112]
84
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
85
+ fpn_interp_model: nearest
86
+
87
+ memory_attention:
88
+ _target_: sam2.modeling.memory_attention.MemoryAttention
89
+ d_model: 256
90
+ pos_enc_at_input: true
91
+ layer:
92
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
93
+ activation: relu
94
+ dim_feedforward: 2048
95
+ dropout: 0.1
96
+ pos_enc_at_attn: false
97
+ self_attention:
98
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
99
+ rope_theta: 10000.0
100
+ feat_sizes: [64, 64]
101
+ embedding_dim: 256
102
+ num_heads: 1
103
+ downsample_rate: 1
104
+ dropout: 0.1
105
+ d_model: 256
106
+ pos_enc_at_cross_attn_keys: true
107
+ pos_enc_at_cross_attn_queries: false
108
+ cross_attention:
109
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
110
+ rope_theta: 10000.0
111
+ feat_sizes: [64, 64]
112
+ rope_k_repeat: True
113
+ embedding_dim: 256
114
+ num_heads: 1
115
+ downsample_rate: 1
116
+ dropout: 0.1
117
+ kv_in_dim: 64
118
+ num_layers: 4
119
+
120
+ memory_encoder:
121
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
122
+ out_dim: 64
123
+ position_encoding:
124
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
125
+ num_pos_feats: 64
126
+ normalize: true
127
+ scale: null
128
+ temperature: 10000
129
+ mask_downsampler:
130
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
131
+ kernel_size: 3
132
+ stride: 2
133
+ padding: 1
134
+ fuser:
135
+ _target_: sam2.modeling.memory_encoder.Fuser
136
+ layer:
137
+ _target_: sam2.modeling.memory_encoder.CXBlock
138
+ dim: 256
139
+ kernel_size: 7
140
+ padding: 3
141
+ layer_scale_init_value: 1e-6
142
+ use_dwconv: True # depth-wise convs
143
+ num_layers: 2
144
+
145
+ num_maskmem: 7
146
+ image_size: ${scratch.resolution}
147
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
148
+ sigmoid_scale_for_mem_enc: 20.0
149
+ sigmoid_bias_for_mem_enc: -10.0
150
+ use_mask_input_as_output_without_sam: true
151
+ # Memory
152
+ directly_add_no_mem_embed: true
153
+ no_obj_embed_spatial: true
154
+ # use high-resolution feature map in the SAM mask decoder
155
+ use_high_res_features_in_sam: true
156
+ # output 3 masks on the first click on initial conditioning frames
157
+ multimask_output_in_sam: true
158
+ # SAM heads
159
+ iou_prediction_use_sigmoid: True
160
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
161
+ use_obj_ptrs_in_encoder: true
162
+ add_tpos_enc_to_obj_ptrs: true
163
+ proj_tpos_enc_in_obj_ptrs: true
164
+ use_signed_tpos_enc_to_obj_ptrs: true
165
+ only_obj_ptrs_in_the_past_for_eval: true
166
+ # object occlusion prediction
167
+ pred_obj_scores: true
168
+ pred_obj_scores_mlp: true
169
+ fixed_no_obj_ptr: true
170
+ # multimask tracking settings
171
+ multimask_output_for_tracking: true
172
+ use_multimask_token_for_obj_ptr: true
173
+ multimask_min_pt_num: 0
174
+ multimask_max_pt_num: 1
175
+ use_mlp_for_obj_ptr_proj: true
176
+ # Compilation flag
177
+ # compile_image_encoder: False
178
+
179
+ ####### Training specific params #######
180
+ # box/point input and corrections
181
+ prob_to_use_pt_input_for_train: 0.5
182
+ prob_to_use_pt_input_for_eval: 0.0
183
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
184
+ prob_to_use_box_input_for_eval: 0.0
185
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
186
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
187
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
188
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
189
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
190
+ # maximum 2 initial conditioning frames
191
+ num_init_cond_frames_for_train: 2
192
+ rand_init_cond_frames_for_train: True # random 1~2
193
+ num_correction_pt_per_frame: 7
194
+ use_act_ckpt_iterative_pt_sampling: false
195
+
196
+
197
+
198
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
199
+ forward_backbone_per_frame_for_eval: True
200
+
201
+
202
+ data:
203
+ train:
204
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
205
+ phases_per_epoch: ${scratch.phases_per_epoch}
206
+ batch_sizes:
207
+ - ${scratch.train_batch_size}
208
+
209
+ datasets:
210
+ - _target_: training.dataset.utils.RepeatFactorWrapper
211
+ dataset:
212
+ _target_: training.dataset.utils.ConcatDataset
213
+ datasets:
214
+ - _target_: training.dataset.vos_dataset.VOSDataset
215
+ transforms: ${vos.train_transforms}
216
+ training: true
217
+ video_dataset:
218
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
219
+ img_folder: ${dataset.img_folder}
220
+ gt_folder: ${dataset.gt_folder}
221
+ file_list_txt: ${dataset.file_list_txt}
222
+ sampler:
223
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
224
+ num_frames: ${scratch.num_frames}
225
+ max_num_objects: ${scratch.max_num_objects}
226
+ multiplier: ${dataset.multiplier}
227
+ shuffle: True
228
+ num_workers: ${scratch.num_train_workers}
229
+ pin_memory: True
230
+ drop_last: True
231
+ collate_fn:
232
+ _target_: training.utils.data_utils.collate_fn
233
+ _partial_: true
234
+ dict_key: all
235
+
236
+ optim:
237
+ amp:
238
+ enabled: True
239
+ amp_dtype: bfloat16
240
+
241
+ optimizer:
242
+ _target_: torch.optim.AdamW
243
+
244
+ gradient_clip:
245
+ _target_: training.optimizer.GradientClipper
246
+ max_norm: 0.1
247
+ norm_type: 2
248
+
249
+ param_group_modifiers:
250
+ - _target_: training.optimizer.layer_decay_param_modifier
251
+ _partial_: True
252
+ layer_decay_value: 0.9
253
+ apply_to: 'image_encoder.trunk'
254
+ overrides:
255
+ - pattern: '*pos_embed*'
256
+ value: 1.0
257
+
258
+ options:
259
+ lr:
260
+ - scheduler:
261
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
262
+ start_value: ${scratch.base_lr}
263
+ end_value: ${divide:${scratch.base_lr},10}
264
+ - scheduler:
265
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
266
+ start_value: ${scratch.vision_lr}
267
+ end_value: ${divide:${scratch.vision_lr},10}
268
+ param_names:
269
+ - 'image_encoder.*'
270
+ weight_decay:
271
+ - scheduler:
272
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
273
+ value: 0.1
274
+ - scheduler:
275
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
276
+ value: 0.0
277
+ param_names:
278
+ - '*bias*'
279
+ module_cls_names: ['torch.nn.LayerNorm']
280
+
281
+ loss:
282
+ all:
283
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
284
+ weight_dict:
285
+ loss_mask: 20
286
+ loss_dice: 1
287
+ loss_iou: 1
288
+ loss_class: 1
289
+ supervise_all_iou: true
290
+ iou_use_l1_loss: true
291
+ pred_obj_scores: true
292
+ focal_gamma_obj_score: 0.0
293
+ focal_alpha_obj_score: -1.0
294
+
295
+ distributed:
296
+ backend: nccl
297
+ find_unused_parameters: True
298
+
299
+ logging:
300
+ tensorboard_writer:
301
+ _target_: training.utils.logger.make_tensorboard_logger
302
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
303
+ flush_secs: 120
304
+ should_log: True
305
+ log_dir: ${launcher.experiment_log_dir}/logs
306
+ log_freq: 10
307
+
308
+ # initialize from a SAM 2 checkpoint
309
+ checkpoint:
310
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
311
+ save_freq: 0 # 0 means only the last checkpoint is saved.
312
+ model_weight_initializer:
313
+ _partial_: True
314
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
315
+ strict: True
316
+ ignore_unexpected_keys: null
317
+ ignore_missing_keys: null
318
+
319
+ state_dict:
320
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
321
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
322
+ ckpt_state_dict_keys: ['model']
323
+
324
+ launcher:
325
+ num_nodes: 1
326
+ gpus_per_node: 8
327
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
328
+
329
+ # SLURM args if running on a cluster
330
+ submitit:
331
+ partition: null
332
+ account: null
333
+ qos: null
334
+ cpus_per_task: 10
335
+ use_cluster: false
336
+ timeout_hour: 24
337
+ name: null
338
+ port_range: [10000, 65000]
339
+
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
@@ -0,0 +1,317 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from functools import partial
9
+ from typing import List, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from iopath.common.file_io import g_pathmgr
15
+
16
+ from sam2.modeling.backbones.utils import (
17
+ PatchEmbed,
18
+ window_partition,
19
+ window_unpartition,
20
+ )
21
+
22
+ from sam2.modeling.sam2_utils import DropPath, MLP
23
+
24
+
25
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26
+ if pool is None:
27
+ return x
28
+ # (B, H, W, C) -> (B, C, H, W)
29
+ x = x.permute(0, 3, 1, 2)
30
+ x = pool(x)
31
+ # (B, C, H', W') -> (B, H', W', C)
32
+ x = x.permute(0, 2, 3, 1)
33
+ if norm:
34
+ x = norm(x)
35
+
36
+ return x
37
+
38
+
39
class MultiScaleAttention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) map with optional query pooling.

    Queries, keys and values come from a single linear projection. When a
    ``q_pool`` module is supplied, only the queries are pooled spatially, so
    the output has the pooled spatial size while keys and values keep the
    input resolution.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        """
        Args:
            dim: Channel count of the input.
            dim_out: Channel count of the output (and of each of q, k, v
                before being split across heads).
            num_heads: Number of attention heads.
            q_pool: Optional pooling module applied to the queries only.
        """
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        # A single projection produces q, k and v stacked along the channel axis.
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over all H * W positions of ``x``.

        Args:
            x: Tensor of shape (B, H, W, dim).

        Returns:
            Tensor of shape (B, H', W', dim_out); (H', W') equals (H, W)
            unless ``q_pool`` changes the query resolution.
        """
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes). Compare with None
        # explicitly rather than relying on the module's truth value, which
        # follows __len__ for container modules.
        if self.q_pool is not None:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back to (B, H*W, nheads, C), then restore the spatial grid.
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
82
+
83
+
84
class MultiScaleBlock(nn.Module):
    """Transformer block: (optionally windowed) attention followed by an MLP.

    The block can change the channel count (``dim`` -> ``dim_out``) and, when
    ``q_stride`` is given, builds a max-pool that is handed to the attention
    module as its query pool and is also applied to the projected skip path.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        """
        Args:
            dim: Input channel count.
            dim_out: Output channel count.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden width of the MLP as a multiple of ``dim_out``.
            drop_path: Stochastic-depth rate; ``DropPath`` is used only when
                this is greater than zero, otherwise an identity.
            norm_layer: Normalisation layer factory, or the name of a class in
                ``torch.nn`` (instantiated with ``eps=1e-6``).
            q_stride: Kernel size and stride of the max-pool; ``None``
                disables pooling.
            act_layer: Activation passed to the MLP.
            window_size: Attention window size; ``0`` skips window
                partitioning so attention sees the whole map.
        """
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        # The skip path needs a projection only when the width changes.
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (B, H, W, C) tensor and return the result."""
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection: project (and pool, if configured) to match the
        # attention output.
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)

        # Reverse window partition. The padded-size bookkeeping only matters
        # for the windowed path, so it is confined to it: with
        # window_size == 0 and q_stride set, the pooled window size would be 0
        # and the modulo below would raise ZeroDivisionError.
        if self.window_size > 0:
            if self.q_stride:
                # Shapes have changed due to Q pooling
                window_size = self.window_size // self.q_stride[0]
                H, W = shortcut.shape[1:3]

                pad_h = (window_size - H % window_size) % window_size
                pad_w = (window_size - W % window_size) % window_size
                pad_hw = (H + pad_h, W + pad_w)

            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
167
+
168
+
169
class Hiera(nn.Module):
    """
    Hierarchical backbone built from a stack of ``MultiScaleBlock`` modules.

    Blocks are grouped into stages; the first block of a new stage multiplies
    the channel width and head count, and (for the first ``q_pool``
    transitions) downsamples with ``q_stride``.

    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride between stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        """
        Args:
            embed_dim: Channel width of the first stage.
            num_heads: Attention head count of the first stage.
            drop_path_rate: Final stochastic-depth rate; per-block rates are
                spaced linearly from 0 up to this value.
            q_pool: How many stage transitions downsample with ``q_stride``.
            q_stride: Pooling kernel/stride used at those transitions.
            stages: Number of blocks in each stage.
            dim_mul: Channel multiplier applied at each stage transition.
            head_mul: Head-count multiplier applied at each stage transition.
            window_pos_embed_bkg_spatial_size: Spatial size of the learned
                positional embedding that is interpolated to the input size.
            window_spec: Attention window size for each stage.
            global_att_blocks: Indices of blocks that use window size 0;
                ``None`` disables the override.
            weights_path: Optional checkpoint path; when given, the state
                dict is loaded non-strictly.
            return_interm_layers: When true, ``forward`` returns the features
                of every stage instead of only the last one.
        """
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of every stage.
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # Blocks that open a new stage and therefore pool their queries.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            # First block after a stage end: widen channels and add heads.
            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        # Output channel counts, listed from the last stage to the first.
        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            # NOTE(security): torch.load unpickles the file; only point
            # weights_path at trusted checkpoints (or consider
            # weights_only=True if the checkpoint is a plain state dict).
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            # The load result must be bound to a %s placeholder; without one
            # the logging module reports a formatting error and the
            # missing/unexpected-key summary is never emitted.
            logging.info(
                "loading Hiera: %s", self.load_state_dict(chkpt, strict=False)
            )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        """Build the positional embedding for a feature map of size ``hw``.

        The background embedding is bicubically resized to ``hw`` and the
        window embedding is tiled across it. The tiling factor uses integer
        division, so ``hw`` is presumably expected to be a multiple of the
        window embedding size.

        Returns:
            Tensor of shape (1, h, w, C).
        """
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        # (1, C, h, w) -> (1, h, w, C) to match the channels-last activations.
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run the backbone.

        Returns:
            A list of channels-first feature maps taken at stage ends: every
            stage when ``return_interm_layers`` is true, otherwise only the
            final stage.
        """
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                # (B, H, W, C) -> (B, C, H, W)
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        """Map a parameter name to a layer index.

        Positional and patch embeddings map to 0, ``blocks.<n>`` maps to
        ``n + 1``, and everything else (including ``rel_pos`` parameters)
        maps to ``num_layers + 1``.
        """
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        """Return the number of blocks in the backbone."""
        return len(self.blocks)