trace-tad 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. configs/__init__.py +0 -0
  2. configs/_dataset.py +98 -0
  3. configs/_model.py +52 -0
  4. configs/large.py +149 -0
  5. configs/small.py +146 -0
  6. tools/__init__.py +0 -0
  7. tools/infer.py +603 -0
  8. tools/prep_dataset.py +83 -0
  9. tools/test.py +187 -0
  10. tools/train.py +250 -0
  11. tools/tune_train.py +42 -0
  12. trace_tad/__init__.py +17 -0
  13. trace_tad/cli.py +945 -0
  14. trace_tad/config.py +179 -0
  15. trace_tad/cores/__init__.py +6 -0
  16. trace_tad/cores/eval_engine.py +341 -0
  17. trace_tad/cores/layer_decay_optimizer.py +93 -0
  18. trace_tad/cores/optimizer.py +135 -0
  19. trace_tad/cores/scheduler.py +212 -0
  20. trace_tad/cores/train_engine.py +156 -0
  21. trace_tad/data_prep.py +1183 -0
  22. trace_tad/datasets/__init__.py +8 -0
  23. trace_tad/datasets/base/__init__.py +5 -0
  24. trace_tad/datasets/base/padding_dataset.py +168 -0
  25. trace_tad/datasets/base/sliding_dataset.py +212 -0
  26. trace_tad/datasets/base/util.py +40 -0
  27. trace_tad/datasets/builder.py +65 -0
  28. trace_tad/datasets/thumos.py +148 -0
  29. trace_tad/datasets/transforms/__init__.py +17 -0
  30. trace_tad/datasets/transforms/end_to_end.py +360 -0
  31. trace_tad/datasets/transforms/formatting.py +305 -0
  32. trace_tad/datasets/transforms/loading.py +263 -0
  33. trace_tad/datasets/transforms/video_transforms.py +502 -0
  34. trace_tad/evaluations/__init__.py +5 -0
  35. trace_tad/evaluations/builder.py +29 -0
  36. trace_tad/evaluations/mAP.py +477 -0
  37. trace_tad/evaluations/precision.py +524 -0
  38. trace_tad/export.py +50 -0
  39. trace_tad/jobs/__init__.py +24 -0
  40. trace_tad/jobs/manager.py +705 -0
  41. trace_tad/jobs/models.py +126 -0
  42. trace_tad/model_artifacts.py +84 -0
  43. trace_tad/models/__init__.py +24 -0
  44. trace_tad/models/backbones/__init__.py +4 -0
  45. trace_tad/models/backbones/backbone_wrapper.py +267 -0
  46. trace_tad/models/backbones/vit_adapter.py +463 -0
  47. trace_tad/models/bricks/__init__.py +7 -0
  48. trace_tad/models/bricks/conv.py +112 -0
  49. trace_tad/models/bricks/gradient_ops.py +37 -0
  50. trace_tad/models/bricks/misc.py +21 -0
  51. trace_tad/models/bricks/sgp.py +123 -0
  52. trace_tad/models/bricks/transformer.py +608 -0
  53. trace_tad/models/builder.py +70 -0
  54. trace_tad/models/dense_heads/__init__.py +6 -0
  55. trace_tad/models/dense_heads/anchor_free_head.py +309 -0
  56. trace_tad/models/dense_heads/prior_generator/__init__.py +2 -0
  57. trace_tad/models/dense_heads/prior_generator/point_generator.py +36 -0
  58. trace_tad/models/dense_heads/tridet_bm_head.py +377 -0
  59. trace_tad/models/dense_heads/tridet_head.py +393 -0
  60. trace_tad/models/detectors/__init__.py +6 -0
  61. trace_tad/models/detectors/base.py +83 -0
  62. trace_tad/models/detectors/single_stage.py +138 -0
  63. trace_tad/models/detectors/tridet.py +194 -0
  64. trace_tad/models/detectors/tridet_bm.py +20 -0
  65. trace_tad/models/losses/__init__.py +5 -0
  66. trace_tad/models/losses/boundary_loss.py +202 -0
  67. trace_tad/models/losses/focal_loss.py +166 -0
  68. trace_tad/models/losses/iou_loss.py +47 -0
  69. trace_tad/models/necks/__init__.py +4 -0
  70. trace_tad/models/necks/fpn.py +127 -0
  71. trace_tad/models/necks/temporal_deformable_fpn.py +181 -0
  72. trace_tad/models/projections/__init__.py +2 -0
  73. trace_tad/models/projections/actionformer_proj.py +186 -0
  74. trace_tad/models/projections/tridet_proj.py +140 -0
  75. trace_tad/models/utils/__init__.py +1 -0
  76. trace_tad/models/utils/bbox_tools.py +58 -0
  77. trace_tad/models/utils/iou_tools.py +150 -0
  78. trace_tad/models/utils/misc.py +25 -0
  79. trace_tad/models/utils/post_processing/__init__.py +9 -0
  80. trace_tad/models/utils/post_processing/classifier.py +187 -0
  81. trace_tad/models/utils/post_processing/nms/__init__.py +0 -0
  82. trace_tad/models/utils/post_processing/nms/nms.py +236 -0
  83. trace_tad/models/utils/post_processing/utils.py +160 -0
  84. trace_tad/pipeline.py +375 -0
  85. trace_tad/pipeline_plan.py +488 -0
  86. trace_tad/registry.py +27 -0
  87. trace_tad/server/__init__.py +1 -0
  88. trace_tad/server/app.py +1347 -0
  89. trace_tad/server/jobs_router.py +366 -0
  90. trace_tad/static/annotator/assets/classnames.f9d2a9c9.js +6 -0
  91. trace_tad/static/annotator/assets/index.480a38ed.css +1 -0
  92. trace_tad/static/annotator/assets/index.de688db8.js +1 -0
  93. trace_tad/static/annotator/assets/lodash.5a06a1a1.js +9 -0
  94. trace_tad/static/annotator/assets/moment.40bc58bf.js +8 -0
  95. trace_tad/static/annotator/assets/runtime-dom.4eada9c7.js +21 -0
  96. trace_tad/static/annotator/assets/runtime.a4816b2b.js +1 -0
  97. trace_tad/static/annotator/assets/ui.7b72c5dc.js +8 -0
  98. trace_tad/static/annotator/index.html +32 -0
  99. trace_tad/static/annotator/trace-logo.svg +16 -0
  100. trace_tad/training_resources.py +507 -0
  101. trace_tad/utils/__init__.py +21 -0
  102. trace_tad/utils/auto_tune.py +248 -0
  103. trace_tad/utils/checkpoint.py +37 -0
  104. trace_tad/utils/ema.py +27 -0
  105. trace_tad/utils/logger.py +24 -0
  106. trace_tad/utils/misc.py +67 -0
  107. trace_tad/utils/train_tune.py +226 -0
  108. trace_tad/version.py +1 -0
  109. trace_tad/video_annotation.py +622 -0
  110. trace_tad/weights.py +143 -0
  111. trace_tad-0.2.0.dist-info/METADATA +142 -0
  112. trace_tad-0.2.0.dist-info/RECORD +116 -0
  113. trace_tad-0.2.0.dist-info/WHEEL +5 -0
  114. trace_tad-0.2.0.dist-info/entry_points.txt +2 -0
  115. trace_tad-0.2.0.dist-info/licenses/LICENSE +176 -0
  116. trace_tad-0.2.0.dist-info/top_level.txt +3 -0
configs/__init__.py ADDED
File without changes
configs/_dataset.py ADDED
@@ -0,0 +1,98 @@
1
+ annotation_path = "dataset.json"
2
+ class_map = "classmap.txt"
3
+ data_path = "."
4
+ block_list = None
5
+
6
+ window_size = 256
7
+
8
+ dataset = dict(
9
+ train=dict(
10
+ type="ThumosPaddingDataset",
11
+ ann_file=annotation_path,
12
+ subset_name="training",
13
+ block_list=block_list,
14
+ class_map=class_map,
15
+ data_path=data_path,
16
+ filter_gt=False,
17
+ feature_stride=1,
18
+ sample_stride=1,
19
+ pipeline=[
20
+ dict(type="PrepareVideoInfo", format="mp4"),
21
+ dict(type="VideoInit", num_threads=4),
22
+ dict(
23
+ type="LoadFrames",
24
+ num_clips=1,
25
+ method="random_trunc",
26
+ trunc_len=window_size,
27
+ trunc_thresh=0.5,
28
+ crop_ratio=[0.9, 1.0],
29
+ ),
30
+ dict(type="VideoDecode"),
31
+ dict(type="VideoResize", scale=(-1, 256)),
32
+ dict(type="VideoRandomResizedCrop"),
33
+ dict(type="VideoResize", scale=(224, 224)),
34
+ dict(type="VideoFlip", flip_ratio=0.5),
35
+ dict(type="VideoFormatShape", input_format="NCTHW"),
36
+ dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
37
+ dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
38
+ ],
39
+ ),
40
+ val=dict(
41
+ type="ThumosSlidingDataset",
42
+ ann_file=annotation_path,
43
+ subset_name="validation",
44
+ block_list=block_list,
45
+ class_map=class_map,
46
+ data_path=data_path,
47
+ filter_gt=False,
48
+ feature_stride=1,
49
+ sample_stride=1,
50
+ window_size=window_size,
51
+ window_overlap_ratio=0.25,
52
+ pipeline=[
53
+ dict(type="PrepareVideoInfo", format="mp4"),
54
+ dict(type="VideoInit", num_threads=4),
55
+ dict(type="LoadFrames", num_clips=1, method="sliding_window"),
56
+ dict(type="VideoDecode"),
57
+ dict(type="VideoResize", scale=(-1, 224)),
58
+ dict(type="VideoCenterCrop", crop_size=224),
59
+ dict(type="VideoFormatShape", input_format="NCTHW"),
60
+ dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
61
+ dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
62
+ ],
63
+ ),
64
+ test=dict(
65
+ type="ThumosSlidingDataset",
66
+ ann_file=annotation_path,
67
+ subset_name="validation",
68
+ block_list=block_list,
69
+ class_map=class_map,
70
+ data_path=data_path,
71
+ filter_gt=False,
72
+ test_mode=True,
73
+ feature_stride=1,
74
+ sample_stride=1,
75
+ window_size=window_size,
76
+ window_overlap_ratio=0.5,
77
+ pipeline=[
78
+ dict(type="PrepareVideoInfo", format="mp4"),
79
+ dict(type="VideoInit", num_threads=4),
80
+ dict(type="LoadFrames", num_clips=1, method="sliding_window"),
81
+ dict(type="VideoDecode"),
82
+ dict(type="VideoResize", scale=(-1, 224)),
83
+ dict(type="VideoCenterCrop", crop_size=224),
84
+ dict(type="VideoFormatShape", input_format="NCTHW"),
85
+ dict(type="ConvertToTensor", keys=["imgs"]),
86
+ dict(type="Collect", inputs="imgs", keys=["masks"]),
87
+ ],
88
+ ),
89
+ )
90
+
91
+ evaluation = dict(
92
+ type="Precision",
93
+ subset="validation",
94
+ tiou_thresholds=[0.3, 0.4, 0.5, 0.6, 0.7],
95
+ ground_truth_filename=annotation_path,
96
+ gt_fps=30.0,
97
+ eval_fps=30.0,
98
+ )
configs/_model.py ADDED
@@ -0,0 +1,52 @@
1
+ model = dict(
2
+ type="TriDet",
3
+ projection=dict(
4
+ type="TriDetProj",
5
+ in_channels=2048,
6
+ out_channels=512,
7
+ sgp_mlp_dim=768,
8
+ arch=(2, 2, 5), # layers in embed / stem / branch
9
+ downsample_type="max",
10
+ sgp_win_size=[1, 1, 1, 1, 1, 1],
11
+ k=5,
12
+ init_conv_vars=0,
13
+ conv_cfg=dict(kernel_size=3),
14
+ norm_cfg=dict(type="LN"),
15
+ path_pdrop=0.1,
16
+ use_abs_pe=True,
17
+ max_seq_len=768,
18
+ input_noise=0.0,
19
+ ),
20
+ neck=dict(
21
+ type="FPNIdentity",
22
+ in_channels=512,
23
+ out_channels=512,
24
+ num_levels=6,
25
+ ),
26
+ rpn_head=dict(
27
+ type="TriDetHead",
28
+ num_classes=3,
29
+ in_channels=512,
30
+ feat_channels=512,
31
+ num_convs=2,
32
+ cls_prior_prob=0.01,
33
+ prior_generator=dict(
34
+ type="PointGenerator",
35
+ strides=[1, 2, 4, 8, 16, 32],
36
+ regression_range=[(0, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 10000)],
37
+ ),
38
+ loss_normalizer=100,
39
+ loss_normalizer_momentum=0.9,
40
+ center_sample="radius",
41
+ center_sample_radius=1.5,
42
+ label_smoothing=0.0,
43
+ boundary_kernel_size=3,
44
+ iou_weight_power=0.2,
45
+ num_bins=16,
46
+ loss=dict(
47
+ cls_loss=dict(type="ClassBalancedFocalLoss", beta=0.999),
48
+ reg_loss=dict(type="DIOULoss"),
49
+ iou_rate=dict(type="GIOULoss"),
50
+ ),
51
+ ),
52
+ )
configs/large.py ADDED
@@ -0,0 +1,149 @@
1
+ _base_ = [
2
+ "_dataset.py",
3
+ "_model.py",
4
+ ]
5
+
6
+ window_size = 768
7
+ scale_factor = 1
8
+ chunk_num = window_size * scale_factor // 16
9
+
10
+ dataset = dict(
11
+ train=dict(
12
+ pipeline=[
13
+ dict(type="PrepareVideoInfo", format="mp4"),
14
+ dict(type="VideoInit", num_threads=4, resize=(144, 144)),
15
+ dict(
16
+ type="LoadFrames",
17
+ num_clips=1,
18
+ method="random_trunc",
19
+ trunc_len=window_size,
20
+ trunc_thresh=0.75,
21
+ crop_ratio=[0.9, 1.0],
22
+ scale_factor=scale_factor,
23
+ ),
24
+ dict(type="VideoTemporalAugment", speed_range=[0.8, 1.2], p=0.5),
25
+ dict(type="VideoDecode"),
26
+ dict(type="VideoBatchResize", scale=(144, 144)),
27
+ dict(type="VideoFlip", flip_ratio=0.5),
28
+ dict(type="VideoColorJitter", brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
29
+ dict(type="VideoFormatShape", input_format="NCTHW"),
30
+ dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
31
+ dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
32
+ ],
33
+ ),
34
+ val=dict(
35
+ window_size=window_size,
36
+ pipeline=[
37
+ dict(type="PrepareVideoInfo", format="mp4"),
38
+ dict(type="VideoInit", num_threads=4, resize=(144, 144)),
39
+ dict(type="LoadFrames", num_clips=1, method="random_trunc", scale_factor=scale_factor),
40
+ dict(type="VideoDecode"),
41
+ dict(type="VideoBatchResize", scale=(144, 144)),
42
+ dict(type="VideoFormatShape", input_format="NCTHW"),
43
+ dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
44
+ dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
45
+ ],
46
+ ),
47
+ test=dict(
48
+ window_size=window_size,
49
+ pipeline=[
50
+ dict(type="PrepareVideoInfo", format="mp4"),
51
+ dict(type="VideoInit", num_threads=4, resize=(144, 144)),
52
+ dict(type="LoadFrames", num_clips=1, method="sliding_window", scale_factor=scale_factor),
53
+ dict(type="VideoDecode"),
54
+ dict(type="VideoBatchResize", scale=(144, 144)),
55
+ dict(type="VideoFormatShape", input_format="NCTHW"),
56
+ dict(type="ConvertToTensor", keys=["imgs"]),
57
+ dict(type="Collect", inputs="imgs", keys=["masks"]),
58
+ ],
59
+ ),
60
+ )
61
+
62
+ model = dict(
63
+ backbone=dict(
64
+ type="VisionTransformerAdapter",
65
+ img_size=224,
66
+ patch_size=16,
67
+ embed_dims=1024,
68
+ depth=24,
69
+ num_heads=16,
70
+ mlp_ratio=4,
71
+ qkv_bias=True,
72
+ drop_path_rate=0.3,
73
+ norm_cfg=dict(type="LN", eps=1e-6),
74
+ return_feat_map=True,
75
+ with_cp=True,
76
+ total_frames=window_size * scale_factor,
77
+ adapter_index=list(range(24)),
78
+ custom=dict(
79
+ pretrain="pretrained/vit-large-p16_videomaev2-k400.pth",
80
+ mean=[123.675, 116.28, 103.53],
81
+ std=[58.395, 57.12, 57.375],
82
+ pre_processing_pipeline=[
83
+ dict(type="Rearrange", keys=["frames"], ops="b n c (t1 t) h w -> (b t1) n c t h w", t1=chunk_num),
84
+ ],
85
+ post_processing_pipeline=[
86
+ dict(type="Reduce", keys=["feats"], ops="b n c t h w -> b c t", reduction="mean"),
87
+ dict(type="Rearrange", keys=["feats"], ops="(b t1) c t -> b c (t1 t)", t1=chunk_num),
88
+ dict(type="Interpolate", keys=["feats"], size=window_size),
89
+ ],
90
+ norm_eval=False,
91
+ freeze_backbone=False,
92
+ ),
93
+ ),
94
+ projection=dict(in_channels=1024, input_noise=0.0005),
95
+ )
96
+
97
+ solver = dict(
98
+ train=dict(batch_size=1, num_workers=16, persistent_workers=True, prefetch_factor=4),
99
+ val=dict(batch_size=4, num_workers=16, persistent_workers=True, prefetch_factor=4),
100
+ test=dict(batch_size=4, num_workers=16, persistent_workers=True, prefetch_factor=4),
101
+ clip_grad_norm=1,
102
+ ema=True,
103
+ amp=True,
104
+ accumulation_steps=2,
105
+ compile=False,
106
+ )
107
+
108
+ optimizer = dict(
109
+ type="AdamW",
110
+ lr=7e-5,
111
+ weight_decay=0.025,
112
+ paramwise=True,
113
+ backbone=dict(
114
+ lr=0,
115
+ weight_decay=0,
116
+ custom=[dict(name="adapter", lr=1e-4, weight_decay=0.05)],
117
+ exclude=["backbone"],
118
+ ),
119
+ )
120
+ scheduler = dict(type="LinearWarmupCosineAnnealingLR", warmup_epoch=5, max_epoch=150)
121
+
122
+ inference = dict(load_from_raw_predictions=False, save_raw_prediction=False)
123
+ post_processing = dict(
124
+ nms=dict(
125
+ use_soft_nms=True,
126
+ sigma=0.5,
127
+ max_seg_num=2000,
128
+ # min_score prunes proposals whose (decayed) soft-NMS score falls below
129
+ # it, shrinking the active set before the O(N²) suppression loop. Long
130
+ # videos aggregate 100k+ proposals across overlapping sliding windows,
131
+ # so the pruning meaningfully speeds up per-video NMS. NOTE(review):
132
+ # unlike a near-zero cutoff (e.g. 0.001), 0.05 CAN drop low-confidence
133
+ # detections that mAP-style evaluation would otherwise score — verify
134
+ # metrics are unchanged before assuming outputs are identical.
135
+ min_score=0.05,
136
+ multiclass=True,
137
+ voting_thresh=0.7,
138
+ ),
139
+ save_dict=True,
140
+ )
141
+
142
+ workflow = dict(
143
+ logging_interval=50,
144
+ checkpoint_interval=5,
145
+ val_eval_interval=5,
146
+ val_start_epoch=5,
147
+ )
148
+
149
+ work_dir = "exps/large"
configs/small.py ADDED
@@ -0,0 +1,146 @@
1
+ _base_ = [
2
+ "_dataset.py",
3
+ "_model.py",
4
+ ]
5
+
6
+ window_size = 768
7
+ scale_factor = 1
8
+ chunk_num = window_size * scale_factor // 16
9
+
10
+ dataset = dict(
11
+ train=dict(
12
+ pipeline=[
13
+ dict(type="PrepareVideoInfo", format="mp4"),
14
+ dict(type="VideoInit", num_threads=4, resize=(144, 144)),
15
+ dict(
16
+ type="LoadFrames",
17
+ num_clips=1,
18
+ method="random_trunc",
19
+ trunc_len=window_size,
20
+ trunc_thresh=0.75,
21
+ crop_ratio=[0.9, 1.0],
22
+ scale_factor=scale_factor,
23
+ ),
24
+ dict(type="VideoDecode"),
25
+ dict(type="VideoBatchResize", scale=(144, 144)),
26
+ dict(type="VideoFlip", flip_ratio=0.5),
27
+ dict(type="VideoColorJitter", brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
28
+ dict(type="VideoFormatShape", input_format="NCTHW"),
29
+ dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
30
+ dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
31
+ ],
32
+ ),
33
+ val=dict(
34
+ window_size=window_size,
35
+ pipeline=[
36
+ dict(type="PrepareVideoInfo", format="mp4"),
37
+ dict(type="VideoInit", num_threads=4, resize=(144, 144)),
38
+ dict(type="LoadFrames", num_clips=1, method="random_trunc", scale_factor=scale_factor),
39
+ dict(type="VideoDecode"),
40
+ dict(type="VideoBatchResize", scale=(144, 144)),
41
+ dict(type="VideoFormatShape", input_format="NCTHW"),
42
+ dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
43
+ dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
44
+ ],
45
+ ),
46
+ test=dict(
47
+ window_size=window_size,
48
+ pipeline=[
49
+ dict(type="PrepareVideoInfo", format="mp4"),
50
+ dict(type="VideoInit", num_threads=4, resize=(144, 144)),
51
+ dict(type="LoadFrames", num_clips=1, method="sliding_window", scale_factor=scale_factor),
52
+ dict(type="VideoDecode"),
53
+ dict(type="VideoBatchResize", scale=(144, 144)),
54
+ dict(type="VideoFormatShape", input_format="NCTHW"),
55
+ dict(type="ConvertToTensor", keys=["imgs"]),
56
+ dict(type="Collect", inputs="imgs", keys=["masks"]),
57
+ ],
58
+ ),
59
+ )
60
+
61
+ model = dict(
62
+ backbone=dict(
63
+ type="VisionTransformerAdapter",
64
+ img_size=224,
65
+ patch_size=16,
66
+ embed_dims=384,
67
+ depth=12,
68
+ num_heads=6,
69
+ mlp_ratio=4,
70
+ qkv_bias=True,
71
+ drop_path_rate=0.1,
72
+ norm_cfg=dict(type="LN", eps=1e-6),
73
+ return_feat_map=True,
74
+ with_cp=True,
75
+ total_frames=window_size * scale_factor,
76
+ adapter_index=list(range(12)),
77
+ custom=dict(
78
+ pretrain="pretrained/vit-small-p16_videomae-k400-pre_16x4x1_kinetics-400_my.pth",
79
+ mean=[123.675, 116.28, 103.53],
80
+ std=[58.395, 57.12, 57.375],
81
+ pre_processing_pipeline=[
82
+ dict(type="Rearrange", keys=["frames"], ops="b n c (t1 t) h w -> (b t1) n c t h w", t1=chunk_num),
83
+ ],
84
+ post_processing_pipeline=[
85
+ dict(type="Reduce", keys=["feats"], ops="b n c t h w -> b c t", reduction="mean"),
86
+ dict(type="Rearrange", keys=["feats"], ops="(b t1) c t -> b c (t1 t)", t1=chunk_num),
87
+ dict(type="Interpolate", keys=["feats"], size=window_size),
88
+ ],
89
+ norm_eval=False,
90
+ freeze_backbone=False,
91
+ ),
92
+ ),
93
+ projection=dict(in_channels=384, input_noise=0.0005),
94
+ )
95
+
96
+ solver = dict(
97
+ train=dict(batch_size=1, num_workers=16, persistent_workers=True, prefetch_factor=4),
98
+ val=dict(batch_size=4, num_workers=16, persistent_workers=True, prefetch_factor=4),
99
+ test=dict(batch_size=4, num_workers=16, persistent_workers=True, prefetch_factor=4),
100
+ clip_grad_norm=1,
101
+ ema=True,
102
+ amp=True,
103
+ )
104
+
105
+ optimizer = dict(
106
+ type="AdamW",
107
+ lr=7e-5,
108
+ weight_decay=0.025,
109
+ paramwise=True,
110
+ backbone=dict(
111
+ lr=0,
112
+ weight_decay=0,
113
+ custom=[dict(name="adapter", lr=1e-4, weight_decay=0.05)],
114
+ exclude=["backbone"],
115
+ ),
116
+ )
117
+ scheduler = dict(type="LinearWarmupCosineAnnealingLR", warmup_epoch=5, max_epoch=100)
118
+
119
+ inference = dict(load_from_raw_predictions=False, save_raw_prediction=False)
120
+ post_processing = dict(
121
+ nms=dict(
122
+ use_soft_nms=True,
123
+ sigma=0.5,
124
+ max_seg_num=2000,
125
+ # min_score prunes proposals whose (decayed) soft-NMS score falls below
126
+ # it, shrinking the active set before the O(N²) suppression loop. Long
127
+ # videos aggregate 100k+ proposals across overlapping sliding windows,
128
+ # so the pruning meaningfully speeds up per-video NMS. NOTE(review):
129
+ # unlike a near-zero cutoff (e.g. 0.001), 0.05 CAN drop low-confidence
130
+ # detections that mAP-style evaluation would otherwise score — verify
131
+ # metrics are unchanged before assuming outputs are identical.
132
+ min_score=0.05,
133
+ multiclass=True,
134
+ voting_thresh=0.7,
135
+ ),
136
+ save_dict=True,
137
+ )
138
+
139
+ workflow = dict(
140
+ logging_interval=50,
141
+ checkpoint_interval=2,
142
+ val_eval_interval=2,
143
+ val_start_epoch=0,
144
+ )
145
+
146
+ work_dir = "exps/small"
tools/__init__.py ADDED
File without changes