opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,104 @@
1
+ TASK: classification
2
+ DATA:
3
+ dataset_name: mvfouls
4
+ data_dir: /home/vorajv/opensportslib/SoccerNet/mvfouls
5
+ data_modality: video
6
+ view_type: multi # multi or single
7
+ num_classes: 8 # number of MVFouls classes
8
+ train:
9
+ type: annotations_train.json
10
+ video_path: ${DATA.data_dir}/train
11
+ path: ${DATA.train.video_path}/annotations-train.json
12
+ dataloader:
13
+ batch_size: 8
14
+ shuffle: true
15
+ num_workers: 4
16
+ pin_memory: true
17
+ valid:
18
+ type: annotations_valid.json
19
+ video_path: ${DATA.data_dir}/valid
20
+ path: ${DATA.valid.video_path}/annotations-valid.json
21
+ dataloader:
22
+ batch_size: 1
23
+ num_workers: 1
24
+ shuffle: false
25
+ test:
26
+ type: annotations_test.json
27
+ video_path: ${DATA.data_dir}/test
28
+ path: ${DATA.test.video_path}/annotations-test.json
29
+ dataloader:
30
+ batch_size: 1
31
+ num_workers: 1
32
+ shuffle: false
33
+ num_frames: 16 # 8 before + 8 after the foul
34
+ input_fps: 25 # Original FPS of video
35
+ target_fps: 17 # Temporal downsampling so the clip spans roughly 1 s
36
+ start_frame: 63 # Start frame of clip relative to foul frame
37
+ end_frame: 87 # End frame of clip relative to foul frame
38
+ frame_size: [224, 224] # Spatial resolution (HxW)
39
+ augmentations:
40
+ random_affine: true
41
+ translate: [0.1, 0.1]
42
+ affine_scale: [0.9, 1.0]
43
+ random_perspective: true
44
+ distortion_scale: 0.3
45
+ perspective_prob: 0.5
46
+ random_rotation: true
47
+ rotation_degrees: 5
48
+ color_jitter: true
49
+ jitter_params: [0.2, 0.2, 0.2, 0.1] # brightness, contrast, saturation, hue
50
+ random_horizontal_flip: true
51
+ flip_prob: 0.5
52
+ random_crop: false
53
+
54
+ MODEL:
55
+ type: custom # huggingface, custom
56
+ backbone:
57
+ type: mvit_v2_s # video_mae, r3d_18, mc3_18, r2plus1d_18, s3d, mvit_v2_s
58
+ neck:
59
+ type: MV_Aggregate
60
+ agr_type: max # max, mean, attention
61
+ head:
62
+ type: MV_LinearLayer
63
+ pretrained_model: mvit_v2_s # MCG-NJU/videomae-base, OpenGVLab/VideoMAEv2-Base, r3d_18, mc3_18, r2plus1d_18, s3d, mvit_v2_s
64
+ unfreeze_head: true # for videomae backbone
65
+ unfreeze_last_n_layers: 3 # for videomae backbone
66
+
67
+
68
+ TRAIN:
69
+ monitor: balanced_accuracy # balanced_accuracy, loss
70
+ mode: max # max or min
71
+ enabled: true
72
+ use_weighted_sampler: false
73
+ use_weighted_loss: true
74
+ epochs: 20
75
+ log_interval: 10
76
+ save_every: 2 # previously 5
77
+
78
+ criterion:
79
+ type: CrossEntropyLoss
80
+
81
+ optimizer:
82
+ type: AdamW
83
+ lr: 0.0001 # previously 0.001
84
+ backbone_lr: 0.00005
85
+ head_lr: 0.001
86
+ betas: [0.9, 0.999]
87
+ eps: 0.0000001
88
+ weight_decay: 0.001 # use 0.01 for VideoMAE backbones, 0.001 for all others
89
+ amsgrad: false
90
+
91
+ scheduler:
92
+ type: StepLR
93
+ step_size: 3
94
+ gamma: 0.1
95
+
96
+ SYSTEM:
97
+ log_dir: ./logs
98
+ save_dir: ./checkpoints
99
+ use_seed: false
100
+ seed: 42
101
+ GPU: 4
102
+ device: cuda # auto | cuda | cpu
103
+ gpu_id: 0
104
+
@@ -0,0 +1,103 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/opensportslib/sngar-tracking
7
+ preload_data: false
8
+ train:
9
+ type: annotations_train.json
10
+ video_path: ${DATA.data_dir}/train
11
+ path: ${DATA.train.video_path}/train.json
12
+ dataloader:
13
+ batch_size: 32
14
+ shuffle: true
15
+ num_workers: 8
16
+ pin_memory: true
17
+ valid:
18
+ type: annotations_valid.json
19
+ video_path: ${DATA.data_dir}/valid
20
+ path: ${DATA.valid.video_path}/valid.json
21
+ dataloader:
22
+ batch_size: 32
23
+ num_workers: 8
24
+ shuffle: false
25
+ test:
26
+ type: annotations_test.json
27
+ video_path: ${DATA.data_dir}/test
28
+ path: ${DATA.test.video_path}/test.json
29
+ dataloader:
30
+ batch_size: 32
31
+ num_workers: 8
32
+ shuffle: false
33
+ num_frames: 16
34
+ frame_interval: 9
35
+ augmentations:
36
+ vertical_flip: true
37
+ horizontal_flip: true
38
+ team_flip: true
39
+ normalize: true
40
+ num_objects: 23
41
+ feature_dim: 8
42
+ pitch_half_length: 85.0
43
+ pitch_half_width: 50.0
44
+ max_displacement: 110.0
45
+ max_ball_height: 30.0
46
+
47
+ MODEL:
48
+ type: custom
49
+ backbone:
50
+ type: graph_conv
51
+ encoder: graphconv
52
+ hidden_dim: 64
53
+ num_layers: 20
54
+ dropout: 0.1
55
+ neck:
56
+ type: TemporalAggregation
57
+ agr_type: maxpool
58
+ hidden_dim: 64
59
+ dropout: 0.1
60
+ use_position_encoding: true
61
+ head:
62
+ type: TrackingClassifier
63
+ hidden_dim: 64
64
+ dropout: 0.1
65
+ num_classes: 10
66
+ edge: positional
67
+ k: 8
68
+ r: 15.0
69
+
70
+ TRAIN:
71
+ monitor: loss # balanced_accuracy, loss
72
+ mode: min # max or min
73
+ enabled: true
74
+ use_weighted_sampler: true
75
+ use_weighted_loss: false
76
+ samples_per_class: 4000
77
+ epochs: 10
78
+ patience: 10
79
+ save_every: 20
80
+ detailed_results: true
81
+
82
+ optimizer:
83
+ type: Adam
84
+ lr: 0.001
85
+
86
+ scheduler:
87
+ type: ReduceLROnPlateau
88
+ mode: ${TRAIN.mode}
89
+ patience: 10
90
+ factor: 0.1
91
+ min_lr: 1e-8
92
+
93
+ criterion:
94
+ type: CrossEntropyLoss
95
+
96
+ SYSTEM:
97
+ log_dir: ./logs
98
+ save_dir: ./checkpoints_tracking
99
+ use_seed: true
100
+ seed: 42
101
+ GPU: 4
102
+ device: cuda # auto | cuda | cpu
103
+ gpu_id: 0
@@ -0,0 +1,79 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
+ preload_data: false
8
+ annotations:
9
+ train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
+ valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
+ test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
+ num_frames: 16
13
+ frame_interval: 9
14
+ augmentations:
15
+ vertical_flip: true
16
+ horizontal_flip: true
17
+ team_flip: true
18
+ normalize: true
19
+ num_workers: 20
20
+ train_batch_size: 32
21
+ valid_batch_size: 32
22
+ num_objects: 23
23
+ feature_dim: 8
24
+ pitch_half_length: 85.0
25
+ pitch_half_width: 50.0
26
+ max_displacement: 110.0
27
+ max_ball_height: 30.0
28
+
29
+ MODEL:
30
+ type: custom
31
+ backbone:
32
+ type: graph_conv
33
+ encoder: gin
34
+ hidden_dim: 64
35
+ num_layers: 20
36
+ dropout: 0.1
37
+ neck:
38
+ type: TemporalAggregation
39
+ agr_type: avgpool
40
+ hidden_dim: 64
41
+ dropout: 0.1
42
+ head:
43
+ type: TrackingClassifier
44
+ hidden_dim: 64
45
+ dropout: 0.1
46
+ num_classes: 10
47
+ edge: positional
48
+ k: 8
49
+
50
+ TRAIN:
51
+ enabled: true
52
+ use_weighted_sampler: true
53
+ use_weighted_loss: false
54
+ samples_per_class: 4000
55
+ epochs: 100
56
+ patience: 10
57
+ save_every: 20
58
+
59
+ optimizer:
60
+ type: Adam
61
+ lr: 0.001
62
+
63
+ scheduler:
64
+ type: ReduceLROnPlateau
65
+ mode: min
66
+ patience: 10
67
+ factor: 0.1
68
+ min_lr: 1e-8
69
+
70
+ criterion:
71
+ type: CrossEntropyLoss
72
+
73
+ save_dir: ./checkpoints_tracking
74
+
75
+ SYSTEM:
76
+ log_dir: ./logs
77
+ seed: 42
78
+ device: cuda
79
+ gpu_id: 0
@@ -0,0 +1,79 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
+ preload_data: false
8
+ annotations:
9
+ train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
+ valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
+ test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
+ num_frames: 16
13
+ frame_interval: 9
14
+ augmentations:
15
+ vertical_flip: true
16
+ horizontal_flip: true
17
+ team_flip: true
18
+ normalize: true
19
+ num_workers: 20
20
+ train_batch_size: 32
21
+ valid_batch_size: 32
22
+ num_objects: 23
23
+ feature_dim: 8
24
+ pitch_half_length: 85.0
25
+ pitch_half_width: 50.0
26
+ max_displacement: 110.0
27
+ max_ball_height: 30.0
28
+
29
+ MODEL:
30
+ type: custom
31
+ backbone:
32
+ type: graph_conv
33
+ encoder: gin
34
+ hidden_dim: 64
35
+ num_layers: 20
36
+ dropout: 0.1
37
+ neck:
38
+ type: TemporalAggregation
39
+ agr_type: maxpool
40
+ hidden_dim: 64
41
+ dropout: 0.1
42
+ head:
43
+ type: TrackingClassifier
44
+ hidden_dim: 64
45
+ dropout: 0.1
46
+ num_classes: 10
47
+ edge: positional
48
+ k: 8
49
+
50
+ TRAIN:
51
+ enabled: true
52
+ use_weighted_sampler: true
53
+ use_weighted_loss: false
54
+ samples_per_class: 4000
55
+ epochs: 100
56
+ patience: 10
57
+ save_every: 20
58
+
59
+ optimizer:
60
+ type: Adam
61
+ lr: 0.001
62
+
63
+ scheduler:
64
+ type: ReduceLROnPlateau
65
+ mode: min
66
+ patience: 10
67
+ factor: 0.1
68
+ min_lr: 1e-8
69
+
70
+ criterion:
71
+ type: CrossEntropyLoss
72
+
73
+ save_dir: ./checkpoints_tracking
74
+
75
+ SYSTEM:
76
+ log_dir: ./logs
77
+ seed: 42
78
+ device: cuda
79
+ gpu_id: 0
@@ -0,0 +1,79 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
+ preload_data: false
8
+ annotations:
9
+ train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
+ valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
+ test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
+ num_frames: 16
13
+ frame_interval: 9
14
+ augmentations:
15
+ vertical_flip: true
16
+ horizontal_flip: true
17
+ team_flip: true
18
+ normalize: true
19
+ num_workers: 20
20
+ train_batch_size: 32
21
+ valid_batch_size: 32
22
+ num_objects: 23
23
+ feature_dim: 8
24
+ pitch_half_length: 85.0
25
+ pitch_half_width: 50.0
26
+ max_displacement: 110.0
27
+ max_ball_height: 30.0
28
+
29
+ MODEL:
30
+ type: custom
31
+ backbone:
32
+ type: graph_conv
33
+ encoder: graphconv
34
+ hidden_dim: 64
35
+ num_layers: 20
36
+ dropout: 0.1
37
+ neck:
38
+ type: TemporalAggregation
39
+ agr_type: maxpool
40
+ hidden_dim: 64
41
+ dropout: 0.1
42
+ head:
43
+ type: TrackingClassifier
44
+ hidden_dim: 64
45
+ dropout: 0.1
46
+ num_classes: 10
47
+ edge: positional
48
+ k: 8
49
+
50
+ TRAIN:
51
+ enabled: true
52
+ use_weighted_sampler: true
53
+ use_weighted_loss: false
54
+ samples_per_class: 4000
55
+ epochs: 100
56
+ patience: 10
57
+ save_every: 20
58
+
59
+ optimizer:
60
+ type: Adam
61
+ lr: 0.001
62
+
63
+ scheduler:
64
+ type: ReduceLROnPlateau
65
+ mode: min
66
+ patience: 10
67
+ factor: 0.1
68
+ min_lr: 1e-8
69
+
70
+ criterion:
71
+ type: CrossEntropyLoss
72
+
73
+ save_dir: ./checkpoints_tracking
74
+
75
+ SYSTEM:
76
+ log_dir: ./logs
77
+ seed: 42
78
+ device: cuda
79
+ gpu_id: 0
@@ -0,0 +1,79 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
+ preload_data: false
8
+ annotations:
9
+ train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
+ valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
+ test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
+ num_frames: 16
13
+ frame_interval: 9
14
+ augmentations:
15
+ vertical_flip: true
16
+ horizontal_flip: true
17
+ team_flip: true
18
+ normalize: true
19
+ num_workers: 20
20
+ train_batch_size: 32
21
+ valid_batch_size: 32
22
+ num_objects: 23
23
+ feature_dim: 8
24
+ pitch_half_length: 85.0
25
+ pitch_half_width: 50.0
26
+ max_displacement: 110.0
27
+ max_ball_height: 30.0
28
+
29
+ MODEL:
30
+ type: custom
31
+ backbone:
32
+ type: graph_conv
33
+ encoder: sageconv
34
+ hidden_dim: 64
35
+ num_layers: 20
36
+ dropout: 0.1
37
+ neck:
38
+ type: TemporalAggregation
39
+ agr_type: maxpool
40
+ hidden_dim: 64
41
+ dropout: 0.1
42
+ head:
43
+ type: TrackingClassifier
44
+ hidden_dim: 64
45
+ dropout: 0.1
46
+ num_classes: 10
47
+ edge: positional
48
+ k: 8
49
+
50
+ TRAIN:
51
+ enabled: true
52
+ use_weighted_sampler: true
53
+ use_weighted_loss: false
54
+ samples_per_class: 4000
55
+ epochs: 100
56
+ patience: 10
57
+ save_every: 20
58
+
59
+ optimizer:
60
+ type: Adam
61
+ lr: 0.001
62
+
63
+ scheduler:
64
+ type: ReduceLROnPlateau
65
+ mode: min
66
+ patience: 10
67
+ factor: 0.1
68
+ min_lr: 1e-8
69
+
70
+ criterion:
71
+ type: CrossEntropyLoss
72
+
73
+ save_dir: ./checkpoints_tracking
74
+
75
+ SYSTEM:
76
+ log_dir: ./logs
77
+ seed: 42
78
+ device: cuda
79
+ gpu_id: 0
@@ -0,0 +1,79 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
+ preload_data: false
8
+ annotations:
9
+ train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
+ valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
+ test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
+ num_frames: 16
13
+ frame_interval: 9
14
+ augmentations:
15
+ vertical_flip: true
16
+ horizontal_flip: true
17
+ team_flip: true
18
+ normalize: true
19
+ num_workers: 20
20
+ train_batch_size: 32
21
+ valid_batch_size: 32
22
+ num_objects: 23
23
+ feature_dim: 8
24
+ pitch_half_length: 85.0
25
+ pitch_half_width: 50.0
26
+ max_displacement: 110.0
27
+ max_ball_height: 30.0
28
+
29
+ MODEL:
30
+ type: custom
31
+ backbone:
32
+ type: graph_conv
33
+ encoder: gin
34
+ hidden_dim: 64
35
+ num_layers: 20
36
+ dropout: 0.1
37
+ neck:
38
+ type: TemporalAggregation
39
+ agr_type: maxpool
40
+ hidden_dim: 64
41
+ dropout: 0.1
42
+ head:
43
+ type: TrackingClassifier
44
+ hidden_dim: 64
45
+ dropout: 0.1
46
+ num_classes: 10
47
+ edge: positional
48
+ k: 8
49
+
50
+ TRAIN:
51
+ enabled: true
52
+ use_weighted_sampler: true
53
+ use_weighted_loss: false
54
+ samples_per_class: 4000
55
+ epochs: 100
56
+ patience: 10
57
+ save_every: 20
58
+
59
+ optimizer:
60
+ type: Adam
61
+ lr: 0.001
62
+
63
+ scheduler:
64
+ type: ReduceLROnPlateau
65
+ mode: min
66
+ patience: 10
67
+ factor: 0.1
68
+ min_lr: 1e-8
69
+
70
+ criterion:
71
+ type: CrossEntropyLoss
72
+
73
+ save_dir: ./checkpoints_tracking
74
+
75
+ SYSTEM:
76
+ log_dir: ./logs
77
+ seed: 42
78
+ device: cuda
79
+ gpu_id: 0
@@ -0,0 +1,79 @@
1
+ TASK: classification
2
+
3
+ DATA:
4
+ dataset_name: sngar
5
+ data_modality: tracking_parquet
6
+ data_dir: /home/karkid/temporal-localization/data/tracking_dataset
7
+ preload_data: false
8
+ annotations:
9
+ train: /home/karkid/temporal-localization/data/tracking_dataset/train/train.json
10
+ valid: /home/karkid/temporal-localization/data/tracking_dataset/valid/valid.json
11
+ test: /home/karkid/temporal-localization/data/tracking_dataset/test/test.json
12
+ num_frames: 16
13
+ frame_interval: 9
14
+ augmentations:
15
+ vertical_flip: true
16
+ horizontal_flip: true
17
+ team_flip: true
18
+ normalize: true
19
+ num_workers: 20
20
+ train_batch_size: 32
21
+ valid_batch_size: 32
22
+ num_objects: 23
23
+ feature_dim: 8
24
+ pitch_half_length: 85.0
25
+ pitch_half_width: 50.0
26
+ max_displacement: 110.0
27
+ max_ball_height: 30.0
28
+
29
+ MODEL:
30
+ type: custom
31
+ backbone:
32
+ type: graph_conv
33
+ encoder: gin
34
+ hidden_dim: 64
35
+ num_layers: 20
36
+ dropout: 0.1
37
+ neck:
38
+ type: TemporalAggregation
39
+ agr_type: maxpool
40
+ hidden_dim: 64
41
+ dropout: 0.1
42
+ head:
43
+ type: TrackingClassifier
44
+ hidden_dim: 64
45
+ dropout: 0.1
46
+ num_classes: 10
47
+ edge: none
48
+ k: 8
49
+
50
+ TRAIN:
51
+ enabled: true
52
+ use_weighted_sampler: true
53
+ use_weighted_loss: false
54
+ samples_per_class: 4000
55
+ epochs: 100
56
+ patience: 10
57
+ save_every: 20
58
+
59
+ optimizer:
60
+ type: Adam
61
+ lr: 0.001
62
+
63
+ scheduler:
64
+ type: ReduceLROnPlateau
65
+ mode: min
66
+ patience: 10
67
+ factor: 0.1
68
+ min_lr: 1e-8
69
+
70
+ criterion:
71
+ type: CrossEntropyLoss
72
+
73
+ save_dir: ./checkpoints_tracking
74
+
75
+ SYSTEM:
76
+ log_dir: ./logs
77
+ seed: 42
78
+ device: cuda
79
+ gpu_id: 0