neuro_sam-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
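The listing above covers two top-level import packages, neuro_sam (the napari plugin, path tracing, and Prob-UNet code) and sam2 (a bundled copy of the SAM 2 modeling code, predictors, and configs), matching the two lines recorded in top_level.txt. A minimal post-install sanity check might look like the sketch below; the pip command and the distribution name passed to importlib.metadata are assumptions based on the wheel filename, not something stated in this diff.

# Assumes the wheel was installed, e.g. `pip install neuro_sam-0.1.0-py3-none-any.whl`.
import importlib.metadata

import neuro_sam   # napari plugin, path tracing, Prob-UNet utilities
import sam2        # bundled SAM 2 modeling code, predictors, and configs

print(importlib.metadata.version("neuro-sam"))   # expected: 0.1.0
print(neuro_sam.__file__)
print(sam2.__file__)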
@@ -0,0 +1,339 @@
+ # @package _global_
+
+ scratch:
+ resolution: 1024
+ train_batch_size: 1
+ num_train_workers: 10
+ num_frames: 8
+ max_num_objects: 3
+ base_lr: 5.0e-6
+ vision_lr: 3.0e-06
+ phases_per_epoch: 1
+ num_epochs: 40
+
+ dataset:
+ # PATHS to Dataset
+ img_folder: null # PATH to MOSE JPEGImages folder
+ gt_folder: null # PATH to MOSE Annotations folder
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
+ multiplier: 2
+
+ # Video transforms
+ vos:
+ train_transforms:
+ - _target_: training.dataset.transforms.ComposeAPI
+ transforms:
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomAffine
+ degrees: 25
+ shear: 20
+ image_interpolation: bilinear
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomResizeAPI
+ sizes: ${scratch.resolution}
+ square: true
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: True
+ brightness: 0.1
+ contrast: 0.03
+ saturation: 0.03
+ hue: null
+ - _target_: training.dataset.transforms.RandomGrayscale
+ p: 0.05
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: False
+ brightness: 0.1
+ contrast: 0.05
+ saturation: 0.05
+ hue: null
+ - _target_: training.dataset.transforms.ToTensorAPI
+ - _target_: training.dataset.transforms.NormalizeAPI
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+
+ trainer:
+ _target_: training.trainer.Trainer
+ mode: train_only
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
+ accelerator: cuda
+ seed_value: 123
+
+ model:
+ _target_: training.model.sam2.SAM2Train
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ drop_path_rate: 0.1
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: ${scratch.resolution}
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # compile_image_encoder: False
+
+ ####### Training specific params #######
+ # box/point input and corrections
+ prob_to_use_pt_input_for_train: 0.5
+ prob_to_use_pt_input_for_eval: 0.0
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
+ prob_to_use_box_input_for_eval: 0.0
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
+ # maximum 2 initial conditioning frames
+ num_init_cond_frames_for_train: 2
+ rand_init_cond_frames_for_train: True # random 1~2
+ num_correction_pt_per_frame: 7
+ use_act_ckpt_iterative_pt_sampling: false
+
+
+
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
+ forward_backbone_per_frame_for_eval: True
+
+
+ data:
+ train:
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
+ phases_per_epoch: ${scratch.phases_per_epoch}
+ batch_sizes:
+ - ${scratch.train_batch_size}
+
+ datasets:
+ - _target_: training.dataset.utils.RepeatFactorWrapper
+ dataset:
+ _target_: training.dataset.utils.ConcatDataset
+ datasets:
+ - _target_: training.dataset.vos_dataset.VOSDataset
+ transforms: ${vos.train_transforms}
+ training: true
+ video_dataset:
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
+ img_folder: ${dataset.img_folder}
+ gt_folder: ${dataset.gt_folder}
+ file_list_txt: ${dataset.file_list_txt}
+ sampler:
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
+ num_frames: ${scratch.num_frames}
+ max_num_objects: ${scratch.max_num_objects}
+ multiplier: ${dataset.multiplier}
+ shuffle: True
+ num_workers: ${scratch.num_train_workers}
+ pin_memory: True
+ drop_last: True
+ collate_fn:
+ _target_: training.utils.data_utils.collate_fn
+ _partial_: true
+ dict_key: all
+
+ optim:
+ amp:
+ enabled: True
+ amp_dtype: bfloat16
+
+ optimizer:
+ _target_: torch.optim.AdamW
+
+ gradient_clip:
+ _target_: training.optimizer.GradientClipper
+ max_norm: 0.1
+ norm_type: 2
+
+ param_group_modifiers:
+ - _target_: training.optimizer.layer_decay_param_modifier
+ _partial_: True
+ layer_decay_value: 0.9
+ apply_to: 'image_encoder.trunk'
+ overrides:
+ - pattern: '*pos_embed*'
+ value: 1.0
+
+ options:
+ lr:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.base_lr}
+ end_value: ${divide:${scratch.base_lr},10}
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.vision_lr}
+ end_value: ${divide:${scratch.vision_lr},10}
+ param_names:
+ - 'image_encoder.*'
+ weight_decay:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.1
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.0
+ param_names:
+ - '*bias*'
+ module_cls_names: ['torch.nn.LayerNorm']
+
+ loss:
+ all:
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
+ weight_dict:
+ loss_mask: 20
+ loss_dice: 1
+ loss_iou: 1
+ loss_class: 1
+ supervise_all_iou: true
+ iou_use_l1_loss: true
+ pred_obj_scores: true
+ focal_gamma_obj_score: 0.0
+ focal_alpha_obj_score: -1.0
+
+ distributed:
+ backend: nccl
+ find_unused_parameters: True
+
+ logging:
+ tensorboard_writer:
+ _target_: training.utils.logger.make_tensorboard_logger
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
+ flush_secs: 120
+ should_log: True
+ log_dir: ${launcher.experiment_log_dir}/logs
+ log_freq: 10
+
+ # initialize from a SAM 2 checkpoint
+ checkpoint:
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
+ save_freq: 0 # 0 only last checkpoint is saved.
+ model_weight_initializer:
+ _partial_: True
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
+ strict: True
+ ignore_unexpected_keys: null
+ ignore_missing_keys: null
+
+ state_dict:
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
+ ckpt_state_dict_keys: ['model']
+
+ launcher:
+ num_nodes: 1
+ gpus_per_node: 8
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
+
+ # SLURM args if running on a cluster
+ submitit:
+ partition: null
+ account: null
+ qos: null
+ cpus_per_task: 10
+ use_cluster: false
+ timeout_hour: 24
+ name: null
+ port_range: [10000, 65000]
+
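The +339-line hunk above matches sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml, the only file in the listing with 339 added lines. Its max_epochs and scheduler end values rely on the custom OmegaConf resolvers ${times:...} and ${divide:...}, which the upstream SAM 2 training entry point registers; note that the training.* modules referenced by the _target_ keys are not shipped in this wheel. A minimal sketch of resolving those interpolations outside that script, for illustration only (the resolver registration shown here is an assumption, not the package's own code):

from omegaconf import OmegaConf

# Stand-ins for the resolvers the upstream SAM 2 training script registers.
OmegaConf.register_new_resolver("times", lambda a, b: a * b)
OmegaConf.register_new_resolver("divide", lambda a, b: a / b)

cfg = OmegaConf.load("sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml")
print(cfg.trainer.max_epochs)   # ${times:40,1} -> 40
print(cfg.scratch.base_lr)      # 5e-06; the cosine LR schedules decay to base_lr/10 via ${divide:...}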
@@ -0,0 +1,335 @@
+ # @package _global_
+
+ scratch:
+ resolution: 1024
+ train_batch_size: 1
+ num_train_workers: 10
+ num_frames: 1
+ max_num_objects: 3
+ base_lr: 5.0e-6
+ vision_lr: 3.0e-06
+ phases_per_epoch: 1
+ num_epochs: 40
+
+ dataset:
+ # PATHS to Dataset
+ img_folder: data/train/JPEGImages # PATH to MOSE JPEGImages folder
+ gt_folder: data/train/Annotations # PATH to MOSE Annotations folder
+ file_list_txt: data/train_lst.txt
+ multiplier: 2
+
+ # Video transforms
+ vos:
+ train_transforms:
+ - _target_: training.dataset.transforms.ComposeAPI
+ transforms:
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomAffine
+ degrees: 25
+ shear: 20
+ image_interpolation: bilinear
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomResizeAPI
+ sizes: ${scratch.resolution}
+ square: true
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: True
+ brightness: 0.1
+ contrast: 0.03
+ saturation: 0.03
+ hue: null
+ - _target_: training.dataset.transforms.RandomGrayscale
+ p: 0.05
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: False
+ brightness: 0.1
+ contrast: 0.05
+ saturation: 0.05
+ hue: null
+ - _target_: training.dataset.transforms.ToTensorAPI
+ - _target_: training.dataset.transforms.NormalizeAPI
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+
+ trainer:
+ _target_: training.trainer.Trainer
+ mode: train_only
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
+ accelerator: cuda
+ seed_value: 123
+
+ model:
+ _target_: training.model.sam2.SAM2Train
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ drop_path_rate: 0.1
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: ${scratch.resolution}
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # compile_image_encoder: False
+
+ ####### Training specific params #######
+ # box/point input and corrections
+ prob_to_use_pt_input_for_train: 0.5
+ prob_to_use_pt_input_for_eval: 0.0
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
+ prob_to_use_box_input_for_eval: 0.0
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
+ # maximum 2 initial conditioning frames
+ num_init_cond_frames_for_train: 2
+ rand_init_cond_frames_for_train: True # random 1~2
+ num_correction_pt_per_frame: 7
+ use_act_ckpt_iterative_pt_sampling: false
+
+
+
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
+ forward_backbone_per_frame_for_eval: True
+
+
+ data:
+ train:
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
+ phases_per_epoch: ${scratch.phases_per_epoch}
+ batch_sizes:
+ - ${scratch.train_batch_size}
+
+ datasets:
+ - _target_: training.dataset.vos_dataset.VOSDataset
+ transforms: ${vos.train_transforms}
+ training: true
+ video_dataset:
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
+ img_folder: ${dataset.img_folder}
+ gt_folder: ${dataset.gt_folder}
+ file_list_txt: ${dataset.file_list_txt}
+ multiplier: ${dataset.multiplier}
+ sampler:
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
+ num_frames: 1
+ max_num_objects: ${scratch.max_num_objects}
+ shuffle: True
+ num_workers: ${scratch.num_train_workers}
+ pin_memory: True
+ drop_last: True
+ collate_fn:
+ _target_: training.utils.data_utils.collate_fn
+ _partial_: true
+ dict_key: all
+
+ optim:
+ amp:
+ enabled: True
+ amp_dtype: bfloat16
+
+ optimizer:
+ _target_: torch.optim.AdamW
+
+ gradient_clip:
+ _target_: training.optimizer.GradientClipper
+ max_norm: 0.1
+ norm_type: 2
+
+ param_group_modifiers:
+ - _target_: training.optimizer.layer_decay_param_modifier
+ _partial_: True
+ layer_decay_value: 0.9
+ apply_to: 'image_encoder.trunk'
+ overrides:
+ - pattern: '*pos_embed*'
+ value: 1.0
+
+ options:
+ lr:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.base_lr}
+ end_value: ${divide:${scratch.base_lr},10}
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.vision_lr}
+ end_value: ${divide:${scratch.vision_lr},10}
+ param_names:
+ - 'image_encoder.*'
+ weight_decay:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.1
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.0
+ param_names:
+ - '*bias*'
+ module_cls_names: ['torch.nn.LayerNorm']
+
+ loss:
+ all:
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
+ weight_dict:
+ loss_mask: 20
+ loss_dice: 1
+ loss_iou: 1
+ loss_class: 1
+ supervise_all_iou: true
+ iou_use_l1_loss: true
+ pred_obj_scores: true
+ focal_gamma_obj_score: 0.0
+ focal_alpha_obj_score: -1.0
+
+ distributed:
+ backend: nccl
+ find_unused_parameters: True
+
+ logging:
+ tensorboard_writer:
+ _target_: training.utils.logger.make_tensorboard_logger
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
+ flush_secs: 120
+ should_log: True
+ log_dir: ${launcher.experiment_log_dir}/logs
+ log_freq: 10
+
+ # initialize from a SAM 2 checkpoint
+ checkpoint:
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
+ save_freq: 0 # 0 only last checkpoint is saved.
+ model_weight_initializer:
+ _partial_: True
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
+ strict: True
+ ignore_unexpected_keys: null
+ ignore_missing_keys: null
+
+ state_dict:
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
+ ckpt_state_dict_keys: ['model']
+
+ launcher:
+ num_nodes: 1
+ gpus_per_node: 8
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
+
+ # SLURM args if running on a cluster
+ submitit:
+ partition: null
+ account: null
+ qos: null
+ cpus_per_task: 10
+ use_cluster: false
+ timeout_hour: 24
+ name: null
+ port_range: [10000, 65000]
+
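This second hunk (+335 lines) matches sam2/configs/train.yaml. It is essentially the MOSE finetune config adapted for single-frame training: num_frames drops from 8 to 1, the dataset paths point at data/train/JPEGImages and data/train/Annotations instead of null placeholders, the RoPE feat_sizes shrink from [64, 64] to [32, 32], and the data pipeline uses a plain VOSDataset rather than the RepeatFactorWrapper/ConcatDataset stack. A small illustrative comparison, not part of the package (paths are relative to the unpacked wheel):

from omegaconf import OmegaConf

mose = OmegaConf.load("sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml")
train = OmegaConf.load("sam2/configs/train.yaml")

print(mose.scratch.num_frames, train.scratch.num_frames)   # 8 1
print(mose.dataset.img_folder, train.dataset.img_folder)   # None data/train/JPEGImages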
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.