opensportslib 0.1.3__tar.gz → 0.1.3.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110):
  1. {opensportslib-0.1.3/opensportslib.egg-info → opensportslib-0.1.3.dev2}/PKG-INFO +8 -4
  2. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/README.md +6 -3
  3. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/examples/quickstart/basic_classification.py +1 -1
  4. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/examples/quickstart/basic_localization.py +1 -1
  5. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/classification.yaml +14 -5
  6. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization.yaml +8 -4
  7. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/sngar-frames.yaml +7 -4
  8. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/sngar-tracking.yaml +8 -8
  9. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/trainer/classification_trainer.py +28 -10
  10. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/classification_dataset.py +8 -10
  11. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/localization_dataset.py +96 -96
  12. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/setup/setup.py +34 -10
  13. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2/opensportslib.egg-info}/PKG-INFO +8 -4
  14. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/SOURCES.txt +3 -0
  15. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/requires.txt +1 -0
  16. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/pyproject.toml +2 -2
  17. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/conftest.py +24 -6
  18. opensportslib-0.1.3.dev2/tests/test_classification_dataset_paths.py +96 -0
  19. opensportslib-0.1.3.dev2/tests/test_classification_trainer_dataloader.py +127 -0
  20. opensportslib-0.1.3.dev2/tests/test_localization_dali_filenames.py +59 -0
  21. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_task_model_api_contract.py +5 -4
  22. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/training/classification.py +3 -3
  23. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/training/localization.py +3 -3
  24. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/LICENSE +0 -0
  25. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/LICENSE-COMMERCIAL +0 -0
  26. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/MANIFEST.in +0 -0
  27. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/__init__.py +0 -0
  28. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/__init__.py +0 -0
  29. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/base_task_model.py +0 -0
  30. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/classification.py +0 -0
  31. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/localization.py +0 -0
  32. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/cli.py +0 -0
  33. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  34. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  35. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  36. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/__init__.py +0 -0
  37. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/__init__.py +0 -0
  38. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/builder.py +0 -0
  39. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/calf.py +0 -0
  40. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/ce.py +0 -0
  41. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/combine.py +0 -0
  42. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/nll.py +0 -0
  43. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
  44. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/optimizer/builder.py +0 -0
  45. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  46. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
  47. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/scheduler/builder.py +0 -0
  48. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/trainer/__init__.py +0 -0
  49. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/trainer/localization_trainer.py +0 -0
  50. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
  51. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/config.py +0 -0
  52. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/data.py +0 -0
  53. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/ddp.py +0 -0
  54. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/default_args.py +0 -0
  55. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/lightning.py +0 -0
  56. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
  57. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/seed.py +0 -0
  58. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/video_processing.py +0 -0
  59. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/wandb.py +0 -0
  60. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/__init__.py +0 -0
  61. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/builder.py +0 -0
  62. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
  63. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
  64. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/metrics/classification_metric.py +0 -0
  65. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/metrics/localization_metric.py +0 -0
  66. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/__init__.py +0 -0
  67. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/backbones/builder.py +0 -0
  68. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/contextaware.py +0 -0
  69. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/e2e.py +0 -0
  70. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
  71. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/tracking.py +0 -0
  72. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/vars.py +0 -0
  73. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/video.py +0 -0
  74. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/video_mae.py +0 -0
  75. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/builder.py +0 -0
  76. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/heads/builder.py +0 -0
  77. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/neck/builder.py +0 -0
  78. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/common.py +0 -0
  79. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
  80. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
  81. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
  82. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
  83. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
  84. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
  85. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/litebase.py +0 -0
  86. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/modules.py +0 -0
  87. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/shift.py +0 -0
  88. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/utils.py +0 -0
  89. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/__init__.py +0 -0
  90. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/_common.py +0 -0
  91. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/hf_transfer.py +0 -0
  92. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/osl_json_to_parquet.py +0 -0
  93. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/parquet_to_osl_json.py +0 -0
  94. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
  95. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/entry_points.txt +0 -0
  96. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/top_level.txt +0 -0
  97. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/setup.cfg +0 -0
  98. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_config_utils_smoke.py +0 -0
  99. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_conversion_tools.py +0 -0
  100. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_hf_transfer_tools.py +0 -0
  101. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_package_smoke.py +0 -0
  102. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_public_apis_smoke.py +0 -0
  103. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_subset_train_infer_integration.py +0 -0
  104. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/build_soccernet_gar.py +0 -0
  105. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/build_soccernet_gar_action_spotting.py +0 -0
  106. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
  107. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
  108. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/download/download_hf_repo.py +0 -0
  109. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/download/download_osl_hf.py +0 -0
  110. {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/download/upload_osl_hf.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.3
3
+ Version: 0.1.3.dev2
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -27,6 +27,7 @@ Requires-Dist: pytorch-lightning
27
27
  Requires-Dist: pandas
28
28
  Requires-Dist: pyarrow
29
29
  Requires-Dist: huggingface_hub
30
+ Requires-Dist: easydict
30
31
  Provides-Extra: test
31
32
  Requires-Dist: pytest; extra == "test"
32
33
  Requires-Dist: pytest-cov; extra == "test"
@@ -112,6 +113,9 @@ Use it as the main entry point to find:
112
113
  - extracted features
113
114
  - pretrained models and checkpoints
114
115
 
116
+ See the [Model Zoo](docs/model-zoo.md) for available pretrained models,
117
+ reported scores, datasets, and loading snippets.
118
+
115
119
  --
116
120
 
117
121
  ## Quickstart
@@ -130,7 +134,7 @@ from opensportslib.apis import ClassificationModel
130
134
 
131
135
  my_model = ClassificationModel(
132
136
  config="/path/to/classification.yaml",
133
- weights="/path/to/weights.pt", # optional
137
+ weights=None, # optional: path or Hugging Face model ID
134
138
  )
135
139
 
136
140
  my_model.train(
@@ -146,7 +150,7 @@ from opensportslib.apis import ClassificationModel
146
150
 
147
151
  my_model = ClassificationModel(
148
152
  config="/path/to/classification.yaml",
149
- weights="/path/to/weights.pt", # optional
153
+ weights=None, # optional: path or Hugging Face model ID
150
154
  )
151
155
 
152
156
  predictions = my_model.infer(
@@ -177,7 +181,7 @@ from opensportslib.apis import LocalizationModel
177
181
 
178
182
  my_model = LocalizationModel(
179
183
  config="/path/to/localization.yaml",
180
- weights="/path/to/weights.pt", # optional
184
+ weights=None, # optional: path or Hugging Face model ID
181
185
  )
182
186
 
183
187
  predictions = my_model.infer(
@@ -78,6 +78,9 @@ Use it as the main entry point to find:
78
78
  - extracted features
79
79
  - pretrained models and checkpoints
80
80
 
81
+ See the [Model Zoo](docs/model-zoo.md) for available pretrained models,
82
+ reported scores, datasets, and loading snippets.
83
+
81
84
  --
82
85
 
83
86
  ## Quickstart
@@ -96,7 +99,7 @@ from opensportslib.apis import ClassificationModel
96
99
 
97
100
  my_model = ClassificationModel(
98
101
  config="/path/to/classification.yaml",
99
- weights="/path/to/weights.pt", # optional
102
+ weights=None, # optional: path or Hugging Face model ID
100
103
  )
101
104
 
102
105
  my_model.train(
@@ -112,7 +115,7 @@ from opensportslib.apis import ClassificationModel
112
115
 
113
116
  my_model = ClassificationModel(
114
117
  config="/path/to/classification.yaml",
115
- weights="/path/to/weights.pt", # optional
118
+ weights=None, # optional: path or Hugging Face model ID
116
119
  )
117
120
 
118
121
  predictions = my_model.infer(
@@ -143,7 +146,7 @@ from opensportslib.apis import LocalizationModel
143
146
 
144
147
  my_model = LocalizationModel(
145
148
  config="/path/to/localization.yaml",
146
- weights="/path/to/weights.pt", # optional
149
+ weights=None, # optional: path or Hugging Face model ID
147
150
  )
148
151
 
149
152
  predictions = my_model.infer(
@@ -9,7 +9,7 @@ def main():
9
9
 
10
10
  my_model = ClassificationModel(
11
11
  config="examples/configs/classification_video.yaml",
12
- weights="/path/to/weights.pt", # optional
12
+ weights=None, # optional: path or Hugging Face model ID
13
13
  )
14
14
 
15
15
  my_model.train(
@@ -9,7 +9,7 @@ def main():
9
9
 
10
10
  my_model = LocalizationModel(
11
11
  config="examples/configs/localization.yaml",
12
- weights="/path/to/weights.pt", # optional
12
+ weights=None, # optional: path or Hugging Face model ID
13
13
  )
14
14
 
15
15
  my_model.train(
@@ -1,35 +1,44 @@
1
1
  TASK: classification
2
2
  DATA:
3
3
  dataset_name: mvfouls
4
- data_dir: /home/vorajv/opensportslib/SoccerNet/mvfouls
4
+ data_dir: /home/giancos/datasets/OpenSportsLab/OSL-XFoul/224p
5
5
  data_modality: video
6
6
  view_type: multi # multi or single
7
7
  num_classes: 8 # mvfoul
8
8
  train:
9
9
  type: annotations_train.json
10
10
  video_path: ${DATA.data_dir}/train
11
- path: ${DATA.train.video_path}/annotations-train.json
11
+ path: ${DATA.train.video_path}/train.json
12
12
  dataloader:
13
13
  batch_size: 8
14
14
  shuffle: true
15
15
  num_workers: 4
16
16
  pin_memory: true
17
+ mp_context: spawn
18
+ persistent_workers: true
19
+ prefetch_factor: 4
17
20
  valid:
18
21
  type: annotations_valid.json
19
22
  video_path: ${DATA.data_dir}/valid
20
- path: ${DATA.valid.video_path}/annotations-valid.json
23
+ path: ${DATA.valid.video_path}/valid.json
21
24
  dataloader:
22
25
  batch_size: 1
23
26
  num_workers: 1
24
27
  shuffle: false
28
+ mp_context: spawn
29
+ persistent_workers: true
30
+ prefetch_factor: 4
25
31
  test:
26
32
  type: annotations_test.json
27
33
  video_path: ${DATA.data_dir}/test
28
- path: ${DATA.test.video_path}/annotations-test.json
34
+ path: ${DATA.test.video_path}/test.json
29
35
  dataloader:
30
36
  batch_size: 1
31
- num_workers: 1
37
+ num_workers: 0
32
38
  shuffle: false
39
+ mp_context: spawn
40
+ persistent_workers: true
41
+ prefetch_factor: 4
33
42
  num_frames: 16 # 8 before + 8 after the foul
34
43
  input_fps: 25 # Original FPS of video
35
44
  target_fps: 17 # Temporal downsampling to 1s clip (approx)
@@ -4,7 +4,7 @@ dali: True
4
4
 
5
5
  DATA:
6
6
  dataset_name: SoccerNet
7
- data_dir: /home/vorajv/opensportslib/SoccerNet/annotations/
7
+ data_dir: /home/giancos/datasets/OpenSportsLab/OSL-SNBAS/224p-2024/
8
8
  classes:
9
9
  - PASS
10
10
  - DRIVE
@@ -37,7 +37,7 @@ DATA:
37
37
  classes: ${DATA.classes}
38
38
  output_map: [data, label]
39
39
  video_path: ${DATA.data_dir}/train/
40
- path: ${DATA.train.video_path}/annotations-2024-224p-train.json
40
+ path: ${DATA.train.video_path}/train.json
41
41
  dataloader:
42
42
  batch_size: 8
43
43
  shuffle: true
@@ -49,10 +49,12 @@ DATA:
49
49
  classes: ${DATA.classes}
50
50
  output_map: [data, label]
51
51
  video_path: ${DATA.data_dir}/valid/
52
- path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
52
+ path: ${DATA.valid.video_path}/valid.json
53
53
  dataloader:
54
54
  batch_size: 8
55
55
  shuffle: true
56
+ num_workers: 4
57
+ pin_memory: true
56
58
 
57
59
  valid_data_frames:
58
60
  type: VideoGameWithDaliVideo
@@ -64,13 +66,15 @@ DATA:
64
66
  dataloader:
65
67
  batch_size: 4
66
68
  shuffle: false
69
+ num_workers: 4
70
+ pin_memory: true
67
71
 
68
72
  test:
69
73
  type: VideoGameWithDaliVideo
70
74
  classes: ${DATA.classes}
71
75
  output_map: [data, label]
72
76
  video_path: ${DATA.data_dir}/test/
73
- path: ${DATA.test.video_path}/annotations-2024-224p-test.json
77
+ path: ${DATA.test.video_path}/test.json
74
78
  results: results_spotting_test
75
79
  nms_window: 2
76
80
  metric: tight
@@ -8,13 +8,14 @@ TASK: classification
8
8
 
9
9
  DATA:
10
10
  dataset_name: sngar
11
- data_dir: /home/spark_user1/opensportslib/sngar-frames
11
+ data_dir: /home/giancos/datasets/OpenSportsLab/soccernetpro-classification-GAR/frames-parquet
12
12
  data_modality: frames_npy
13
13
  # max_samples: 100 # only used for quick testing
14
14
  num_frames: 16
15
15
  frame_size: [224, 224]
16
16
  train:
17
- path: ${DATA.data_dir}/annotations_train.json
17
+ video_path: ${DATA.data_dir}/train
18
+ path: ${DATA.data_dir}/train.json
18
19
  dataloader:
19
20
  batch_size: 8 # for frozen backbone, use 64
20
21
  # for unfrozen backbone, use 32-16-8 depending on the memory available
@@ -22,13 +23,15 @@ DATA:
22
23
  num_workers: 8
23
24
  pin_memory: true
24
25
  valid:
25
- path: ${DATA.data_dir}/annotations_valid.json
26
+ video_path: ${DATA.data_dir}/valid
27
+ path: ${DATA.data_dir}/valid.json
26
28
  dataloader:
27
29
  batch_size: 8
28
30
  num_workers: 8
29
31
  shuffle: false
30
32
  test:
31
- path: ${DATA.data_dir}/annotations_test.json
33
+ video_path: ${DATA.data_dir}/test
34
+ path: ${DATA.data_dir}/test.json
32
35
  dataloader:
33
36
  batch_size: 8
34
37
  num_workers: 8
@@ -9,7 +9,7 @@ TASK: classification
9
9
  DATA:
10
10
  dataset_name: sngar
11
11
  data_modality: tracking_parquet
12
- data_dir: /home/karkid/opensportslib/tracking-dataset
12
+ data_dir: /home/giancos/datasets/OpenSportsLab/soccernetpro-classification-GAR/tracking-parquet
13
13
  preload_data: false
14
14
  train:
15
15
  type: annotations_train.json
@@ -103,10 +103,10 @@ TRAIN:
103
103
  type: CrossEntropyLoss
104
104
 
105
105
  SYSTEM:
106
- log_dir: ./logs
107
- save_dir: ./checkpoints_tracking
108
- use_seed: true
109
- seed: 42
110
- GPU: 4
111
- device: cuda # auto | cuda | cpu
112
- gpu_id: 0
106
+ log_dir: ./logs
107
+ save_dir: ./checkpoints_tracking
108
+ use_seed: true
109
+ seed: 42
110
+ GPU: 1
111
+ device: cuda # auto | cuda | cpu
112
+ gpu_id: 0
@@ -14,6 +14,7 @@ import gc
14
14
  import json
15
15
  import time
16
16
  import logging
17
+ import multiprocessing as mp
17
18
 
18
19
  import torch
19
20
  import tqdm
@@ -865,36 +866,53 @@ class Trainer_Classification:
865
866
  else:
866
867
  val_sampler = None
867
868
 
868
- num_train_workers = self.config.DATA.train.dataloader.num_workers
869
- num_val_workers = self.config.DATA.valid.dataloader.num_workers
869
+ train_num_workers = getattr(self.config.DATA.train.dataloader, "num_workers", 0)
870
+ train_pin_memory = getattr(self.config.DATA.train.dataloader, "pin_memory", self.device.type == "cuda")
871
+ train_mp_context = getattr(self.config.DATA.train.dataloader, "mp_context", None)
872
+ train_persistent_workers = getattr(self.config.DATA.train.dataloader, "persistent_workers", train_num_workers > 0)
873
+ train_prefetch_factor = getattr(self.config.DATA.train.dataloader, "prefetch_factor", 4 if train_num_workers > 0 else None)
874
+
875
+ if train_mp_context is not None:
876
+ train_mp_context = mp.get_context(train_mp_context)
870
877
 
871
878
  train_loader = DataLoader(
872
879
  train_dataset,
873
880
  batch_size=self.config.DATA.train.dataloader.batch_size,
874
881
  shuffle=(train_sampler is None and shuffle),
875
882
  sampler=train_sampler,
876
- num_workers=num_train_workers,
877
- pin_memory=True,
883
+ num_workers=train_num_workers,
884
+ pin_memory=train_pin_memory,
878
885
  collate_fn=collate_fn,
879
886
  worker_init_fn=seed_worker,
880
887
  generator=g,
881
888
  drop_last=True,
882
- persistent_workers=num_train_workers > 0,
883
- prefetch_factor=4 if num_train_workers > 0 else None,
889
+ multiprocessing_context=train_mp_context,
890
+ persistent_workers=train_persistent_workers,
891
+ prefetch_factor=train_prefetch_factor,
884
892
  )
885
893
 
894
+ valid_num_workers = getattr(self.config.DATA.valid.dataloader, "num_workers", 0)
895
+ valid_pin_memory = getattr(self.config.DATA.valid.dataloader, "pin_memory", self.device.type == "cuda")
896
+ valid_mp_context = getattr(self.config.DATA.valid.dataloader, "mp_context", None)
897
+ valid_persistent_workers = getattr(self.config.DATA.valid.dataloader, "persistent_workers", valid_num_workers > 0)
898
+ valid_prefetch_factor = getattr(self.config.DATA.valid.dataloader, "prefetch_factor", 4 if valid_num_workers > 0 else None)
899
+
900
+ if valid_mp_context is not None:
901
+ valid_mp_context = mp.get_context(valid_mp_context)
902
+
886
903
  val_loader = DataLoader(
887
904
  val_dataset,
888
905
  batch_size=self.config.DATA.valid.dataloader.batch_size,
889
906
  shuffle=False,
890
907
  sampler=val_sampler,
891
- num_workers=num_val_workers,
892
- pin_memory=True,
908
+ num_workers=valid_num_workers,
909
+ pin_memory=valid_pin_memory,
893
910
  collate_fn=collate_fn,
894
911
  worker_init_fn=seed_worker,
895
912
  generator=g,
896
- persistent_workers=num_val_workers > 0,
897
- prefetch_factor=4 if num_val_workers > 0 else None,
913
+ multiprocessing_context=valid_mp_context,
914
+ persistent_workers=valid_persistent_workers,
915
+ prefetch_factor=valid_prefetch_factor,
898
916
  )
899
917
 
900
918
  # select the modality-specific trainer.
@@ -73,7 +73,7 @@ class ClassificationDataset(Dataset):
73
73
  self.config = config
74
74
  self.split = split
75
75
  self.exclude_labels = ["Unknown", "Dont know"]
76
- self.data_dir = config.DATA.data_dir
76
+ self.video_path = getattr(config.DATA, split).video_path #config.DATA.data_dir
77
77
  self.processor = None
78
78
 
79
79
  # view_type is optional; only MVFoul uses it as of now
@@ -197,7 +197,7 @@ class VideoDataset(ClassificationDataset):
197
197
  """
198
198
 
199
199
  def __init__(self, config, annotations_path, processor, split="train"):
200
- super().__init__(config, annotations_path, split)
200
+ super().__init__(config, annotations_path, processor, split=split)
201
201
 
202
202
  self.processor = processor
203
203
  self.view_type = getattr(config.DATA, "view_type", "single")
@@ -227,15 +227,13 @@ class VideoDataset(ClassificationDataset):
227
227
  """read a video file, temporally sub-sample, and apply transforms.
228
228
 
229
229
  Args:
230
- path: realtive path (under data_dir) to the video file.
230
+ path: realtive path (under video_path) to the video file.
231
231
 
232
232
  Returns:
233
233
  numpy.ndarray of shape (T, H, W, C).
234
234
  """
235
- full_path = os.path.join(self.config.DATA.data_dir, path)
236
-
237
- if full_path.endswith(".npy"):
238
- frames = np.load(full_path).astype(np.float32) / 255.0
235
+ if path.endswith(".npy"):
236
+ frames = np.load(os.path.join(self.video_path, path)).astype(np.float32) / 255.0
239
237
  if self.transform is not None:
240
238
  frames = self.transform(frames)
241
239
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
@@ -243,7 +241,7 @@ class VideoDataset(ClassificationDataset):
243
241
  frames = (frames - mean) / std
244
242
  return frames
245
243
 
246
- v = read_video(os.path.join(self.config.DATA.data_dir, path))
244
+ v = read_video(os.path.join(self.video_path, path))
247
245
 
248
246
  v = process_frames(
249
247
  v,
@@ -473,7 +471,7 @@ class TrackingDataset(ClassificationDataset):
473
471
  """read a single parquet tracking clip.
474
472
 
475
473
  Args:
476
- path: Relative path (under ``data_dir``) to the parquet
474
+ path: Relative path (under ``video_path``) to the parquet
477
475
  file.
478
476
 
479
477
  Returns:
@@ -481,7 +479,7 @@ class TrackingDataset(ClassificationDataset):
481
479
  """
482
480
  import pandas as pd
483
481
 
484
- full_path = os.path.join(self.data_dir, path)
482
+ full_path = os.path.join(self.video_path, path)
485
483
  return pd.read_parquet(full_path)
486
484
 
487
485
  def __getitem__(self, idx):