opensportslib 0.1.3__tar.gz → 0.1.3.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.3/opensportslib.egg-info → opensportslib-0.1.3.dev2}/PKG-INFO +8 -4
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/README.md +6 -3
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/examples/quickstart/basic_classification.py +1 -1
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/examples/quickstart/basic_localization.py +1 -1
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/classification.yaml +14 -5
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization.yaml +8 -4
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/sngar-frames.yaml +7 -4
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/sngar-tracking.yaml +8 -8
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/trainer/classification_trainer.py +28 -10
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/classification_dataset.py +8 -10
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/localization_dataset.py +96 -96
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/setup/setup.py +34 -10
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2/opensportslib.egg-info}/PKG-INFO +8 -4
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/SOURCES.txt +3 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/requires.txt +1 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/pyproject.toml +2 -2
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/conftest.py +24 -6
- opensportslib-0.1.3.dev2/tests/test_classification_dataset_paths.py +96 -0
- opensportslib-0.1.3.dev2/tests/test_classification_trainer_dataloader.py +127 -0
- opensportslib-0.1.3.dev2/tests/test_localization_dali_filenames.py +59 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_task_model_api_contract.py +5 -4
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/training/classification.py +3 -3
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/training/localization.py +3 -3
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/LICENSE +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/MANIFEST.in +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/base_task_model.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/apis/localization.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/trainer/localization_trainer.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/load_annotations.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/core/utils/wandb.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/metrics/localization_metric.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/contextaware.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/learnablepooling.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/__init__.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/_common.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/hf_transfer.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/osl_json_to_parquet.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/tools/parquet_to_osl_json.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/setup.cfg +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_conversion_tools.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_hf_transfer_tools.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_package_smoke.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_public_apis_smoke.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tests/test_subset_train_infer_integration.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/build_soccernet_gar.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/build_soccernet_gar_action_spotting.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/download/download_hf_repo.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/download/download_osl_hf.py +0 -0
- {opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/tools/download/upload_osl_hf.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.3
|
|
3
|
+
Version: 0.1.3.dev2
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -27,6 +27,7 @@ Requires-Dist: pytorch-lightning
|
|
|
27
27
|
Requires-Dist: pandas
|
|
28
28
|
Requires-Dist: pyarrow
|
|
29
29
|
Requires-Dist: huggingface_hub
|
|
30
|
+
Requires-Dist: easydict
|
|
30
31
|
Provides-Extra: test
|
|
31
32
|
Requires-Dist: pytest; extra == "test"
|
|
32
33
|
Requires-Dist: pytest-cov; extra == "test"
|
|
@@ -112,6 +113,9 @@ Use it as the main entry point to find:
|
|
|
112
113
|
- extracted features
|
|
113
114
|
- pretrained models and checkpoints
|
|
114
115
|
|
|
116
|
+
See the [Model Zoo](docs/model-zoo.md) for available pretrained models,
|
|
117
|
+
reported scores, datasets, and loading snippets.
|
|
118
|
+
|
|
115
119
|
--
|
|
116
120
|
|
|
117
121
|
## Quickstart
|
|
@@ -130,7 +134,7 @@ from opensportslib.apis import ClassificationModel
|
|
|
130
134
|
|
|
131
135
|
my_model = ClassificationModel(
|
|
132
136
|
config="/path/to/classification.yaml",
|
|
133
|
-
weights=
|
|
137
|
+
weights=None, # optional: path or Hugging Face model ID
|
|
134
138
|
)
|
|
135
139
|
|
|
136
140
|
my_model.train(
|
|
@@ -146,7 +150,7 @@ from opensportslib.apis import ClassificationModel
|
|
|
146
150
|
|
|
147
151
|
my_model = ClassificationModel(
|
|
148
152
|
config="/path/to/classification.yaml",
|
|
149
|
-
weights=
|
|
153
|
+
weights=None, # optional: path or Hugging Face model ID
|
|
150
154
|
)
|
|
151
155
|
|
|
152
156
|
predictions = my_model.infer(
|
|
@@ -177,7 +181,7 @@ from opensportslib.apis import LocalizationModel
|
|
|
177
181
|
|
|
178
182
|
my_model = LocalizationModel(
|
|
179
183
|
config="/path/to/localization.yaml",
|
|
180
|
-
weights=
|
|
184
|
+
weights=None, # optional: path or Hugging Face model ID
|
|
181
185
|
)
|
|
182
186
|
|
|
183
187
|
predictions = my_model.infer(
|
|
@@ -78,6 +78,9 @@ Use it as the main entry point to find:
|
|
|
78
78
|
- extracted features
|
|
79
79
|
- pretrained models and checkpoints
|
|
80
80
|
|
|
81
|
+
See the [Model Zoo](docs/model-zoo.md) for available pretrained models,
|
|
82
|
+
reported scores, datasets, and loading snippets.
|
|
83
|
+
|
|
81
84
|
--
|
|
82
85
|
|
|
83
86
|
## Quickstart
|
|
@@ -96,7 +99,7 @@ from opensportslib.apis import ClassificationModel
|
|
|
96
99
|
|
|
97
100
|
my_model = ClassificationModel(
|
|
98
101
|
config="/path/to/classification.yaml",
|
|
99
|
-
weights=
|
|
102
|
+
weights=None, # optional: path or Hugging Face model ID
|
|
100
103
|
)
|
|
101
104
|
|
|
102
105
|
my_model.train(
|
|
@@ -112,7 +115,7 @@ from opensportslib.apis import ClassificationModel
|
|
|
112
115
|
|
|
113
116
|
my_model = ClassificationModel(
|
|
114
117
|
config="/path/to/classification.yaml",
|
|
115
|
-
weights=
|
|
118
|
+
weights=None, # optional: path or Hugging Face model ID
|
|
116
119
|
)
|
|
117
120
|
|
|
118
121
|
predictions = my_model.infer(
|
|
@@ -143,7 +146,7 @@ from opensportslib.apis import LocalizationModel
|
|
|
143
146
|
|
|
144
147
|
my_model = LocalizationModel(
|
|
145
148
|
config="/path/to/localization.yaml",
|
|
146
|
-
weights=
|
|
149
|
+
weights=None, # optional: path or Hugging Face model ID
|
|
147
150
|
)
|
|
148
151
|
|
|
149
152
|
predictions = my_model.infer(
|
|
@@ -1,35 +1,44 @@
|
|
|
1
1
|
TASK: classification
|
|
2
2
|
DATA:
|
|
3
3
|
dataset_name: mvfouls
|
|
4
|
-
data_dir: /home/
|
|
4
|
+
data_dir: /home/giancos/datasets/OpenSportsLab/OSL-XFoul/224p
|
|
5
5
|
data_modality: video
|
|
6
6
|
view_type: multi # multi or single
|
|
7
7
|
num_classes: 8 # mvfoul
|
|
8
8
|
train:
|
|
9
9
|
type: annotations_train.json
|
|
10
10
|
video_path: ${DATA.data_dir}/train
|
|
11
|
-
path: ${DATA.train.video_path}/
|
|
11
|
+
path: ${DATA.train.video_path}/train.json
|
|
12
12
|
dataloader:
|
|
13
13
|
batch_size: 8
|
|
14
14
|
shuffle: true
|
|
15
15
|
num_workers: 4
|
|
16
16
|
pin_memory: true
|
|
17
|
+
mp_context: spawn
|
|
18
|
+
persistent_workers: true
|
|
19
|
+
prefetch_factor: 4
|
|
17
20
|
valid:
|
|
18
21
|
type: annotations_valid.json
|
|
19
22
|
video_path: ${DATA.data_dir}/valid
|
|
20
|
-
path: ${DATA.valid.video_path}/
|
|
23
|
+
path: ${DATA.valid.video_path}/valid.json
|
|
21
24
|
dataloader:
|
|
22
25
|
batch_size: 1
|
|
23
26
|
num_workers: 1
|
|
24
27
|
shuffle: false
|
|
28
|
+
mp_context: spawn
|
|
29
|
+
persistent_workers: true
|
|
30
|
+
prefetch_factor: 4
|
|
25
31
|
test:
|
|
26
32
|
type: annotations_test.json
|
|
27
33
|
video_path: ${DATA.data_dir}/test
|
|
28
|
-
path: ${DATA.test.video_path}/
|
|
34
|
+
path: ${DATA.test.video_path}/test.json
|
|
29
35
|
dataloader:
|
|
30
36
|
batch_size: 1
|
|
31
|
-
num_workers:
|
|
37
|
+
num_workers: 0
|
|
32
38
|
shuffle: false
|
|
39
|
+
mp_context: spawn
|
|
40
|
+
persistent_workers: true
|
|
41
|
+
prefetch_factor: 4
|
|
33
42
|
num_frames: 16 # 8 before + 8 after the foul
|
|
34
43
|
input_fps: 25 # Original FPS of video
|
|
35
44
|
target_fps: 17 # Temporal downsampling to 1s clip (approx)
|
|
@@ -4,7 +4,7 @@ dali: True
|
|
|
4
4
|
|
|
5
5
|
DATA:
|
|
6
6
|
dataset_name: SoccerNet
|
|
7
|
-
data_dir: /home/
|
|
7
|
+
data_dir: /home/giancos/datasets/OpenSportsLab/OSL-SNBAS/224p-2024/
|
|
8
8
|
classes:
|
|
9
9
|
- PASS
|
|
10
10
|
- DRIVE
|
|
@@ -37,7 +37,7 @@ DATA:
|
|
|
37
37
|
classes: ${DATA.classes}
|
|
38
38
|
output_map: [data, label]
|
|
39
39
|
video_path: ${DATA.data_dir}/train/
|
|
40
|
-
path: ${DATA.train.video_path}/
|
|
40
|
+
path: ${DATA.train.video_path}/train.json
|
|
41
41
|
dataloader:
|
|
42
42
|
batch_size: 8
|
|
43
43
|
shuffle: true
|
|
@@ -49,10 +49,12 @@ DATA:
|
|
|
49
49
|
classes: ${DATA.classes}
|
|
50
50
|
output_map: [data, label]
|
|
51
51
|
video_path: ${DATA.data_dir}/valid/
|
|
52
|
-
path: ${DATA.valid.video_path}/
|
|
52
|
+
path: ${DATA.valid.video_path}/valid.json
|
|
53
53
|
dataloader:
|
|
54
54
|
batch_size: 8
|
|
55
55
|
shuffle: true
|
|
56
|
+
num_workers: 4
|
|
57
|
+
pin_memory: true
|
|
56
58
|
|
|
57
59
|
valid_data_frames:
|
|
58
60
|
type: VideoGameWithDaliVideo
|
|
@@ -64,13 +66,15 @@ DATA:
|
|
|
64
66
|
dataloader:
|
|
65
67
|
batch_size: 4
|
|
66
68
|
shuffle: false
|
|
69
|
+
num_workers: 4
|
|
70
|
+
pin_memory: true
|
|
67
71
|
|
|
68
72
|
test:
|
|
69
73
|
type: VideoGameWithDaliVideo
|
|
70
74
|
classes: ${DATA.classes}
|
|
71
75
|
output_map: [data, label]
|
|
72
76
|
video_path: ${DATA.data_dir}/test/
|
|
73
|
-
path: ${DATA.test.video_path}/
|
|
77
|
+
path: ${DATA.test.video_path}/test.json
|
|
74
78
|
results: results_spotting_test
|
|
75
79
|
nms_window: 2
|
|
76
80
|
metric: tight
|
|
@@ -8,13 +8,14 @@ TASK: classification
|
|
|
8
8
|
|
|
9
9
|
DATA:
|
|
10
10
|
dataset_name: sngar
|
|
11
|
-
data_dir: /home/
|
|
11
|
+
data_dir: /home/giancos/datasets/OpenSportsLab/soccernetpro-classification-GAR/frames-parquet
|
|
12
12
|
data_modality: frames_npy
|
|
13
13
|
# max_samples: 100 # only used for quick testing
|
|
14
14
|
num_frames: 16
|
|
15
15
|
frame_size: [224, 224]
|
|
16
16
|
train:
|
|
17
|
-
|
|
17
|
+
video_path: ${DATA.data_dir}/train
|
|
18
|
+
path: ${DATA.data_dir}/train.json
|
|
18
19
|
dataloader:
|
|
19
20
|
batch_size: 8 # for frozen backbone, use 64
|
|
20
21
|
# for unfrozen backbone, use 32-16-8 depending on the memory available
|
|
@@ -22,13 +23,15 @@ DATA:
|
|
|
22
23
|
num_workers: 8
|
|
23
24
|
pin_memory: true
|
|
24
25
|
valid:
|
|
25
|
-
|
|
26
|
+
video_path: ${DATA.data_dir}/valid
|
|
27
|
+
path: ${DATA.data_dir}/valid.json
|
|
26
28
|
dataloader:
|
|
27
29
|
batch_size: 8
|
|
28
30
|
num_workers: 8
|
|
29
31
|
shuffle: false
|
|
30
32
|
test:
|
|
31
|
-
|
|
33
|
+
video_path: ${DATA.data_dir}/test
|
|
34
|
+
path: ${DATA.data_dir}/test.json
|
|
32
35
|
dataloader:
|
|
33
36
|
batch_size: 8
|
|
34
37
|
num_workers: 8
|
|
@@ -9,7 +9,7 @@ TASK: classification
|
|
|
9
9
|
DATA:
|
|
10
10
|
dataset_name: sngar
|
|
11
11
|
data_modality: tracking_parquet
|
|
12
|
-
data_dir: /home/
|
|
12
|
+
data_dir: /home/giancos/datasets/OpenSportsLab/soccernetpro-classification-GAR/tracking-parquet
|
|
13
13
|
preload_data: false
|
|
14
14
|
train:
|
|
15
15
|
type: annotations_train.json
|
|
@@ -103,10 +103,10 @@ TRAIN:
|
|
|
103
103
|
type: CrossEntropyLoss
|
|
104
104
|
|
|
105
105
|
SYSTEM:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
106
|
+
log_dir: ./logs
|
|
107
|
+
save_dir: ./checkpoints_tracking
|
|
108
|
+
use_seed: true
|
|
109
|
+
seed: 42
|
|
110
|
+
GPU: 1
|
|
111
|
+
device: cuda # auto | cuda | cpu
|
|
112
|
+
gpu_id: 0
|
|
@@ -14,6 +14,7 @@ import gc
|
|
|
14
14
|
import json
|
|
15
15
|
import time
|
|
16
16
|
import logging
|
|
17
|
+
import multiprocessing as mp
|
|
17
18
|
|
|
18
19
|
import torch
|
|
19
20
|
import tqdm
|
|
@@ -865,36 +866,53 @@ class Trainer_Classification:
|
|
|
865
866
|
else:
|
|
866
867
|
val_sampler = None
|
|
867
868
|
|
|
868
|
-
|
|
869
|
-
|
|
869
|
+
train_num_workers = getattr(self.config.DATA.train.dataloader, "num_workers", 0)
|
|
870
|
+
train_pin_memory = getattr(self.config.DATA.train.dataloader, "pin_memory", self.device.type == "cuda")
|
|
871
|
+
train_mp_context = getattr(self.config.DATA.train.dataloader, "mp_context", None)
|
|
872
|
+
train_persistent_workers = getattr(self.config.DATA.train.dataloader, "persistent_workers", train_num_workers > 0)
|
|
873
|
+
train_prefetch_factor = getattr(self.config.DATA.train.dataloader, "prefetch_factor", 4 if train_num_workers > 0 else None)
|
|
874
|
+
|
|
875
|
+
if train_mp_context is not None:
|
|
876
|
+
train_mp_context = mp.get_context(train_mp_context)
|
|
870
877
|
|
|
871
878
|
train_loader = DataLoader(
|
|
872
879
|
train_dataset,
|
|
873
880
|
batch_size=self.config.DATA.train.dataloader.batch_size,
|
|
874
881
|
shuffle=(train_sampler is None and shuffle),
|
|
875
882
|
sampler=train_sampler,
|
|
876
|
-
num_workers=
|
|
877
|
-
pin_memory=
|
|
883
|
+
num_workers=train_num_workers,
|
|
884
|
+
pin_memory=train_pin_memory,
|
|
878
885
|
collate_fn=collate_fn,
|
|
879
886
|
worker_init_fn=seed_worker,
|
|
880
887
|
generator=g,
|
|
881
888
|
drop_last=True,
|
|
882
|
-
|
|
883
|
-
|
|
889
|
+
multiprocessing_context=train_mp_context,
|
|
890
|
+
persistent_workers=train_persistent_workers,
|
|
891
|
+
prefetch_factor=train_prefetch_factor,
|
|
884
892
|
)
|
|
885
893
|
|
|
894
|
+
valid_num_workers = getattr(self.config.DATA.valid.dataloader, "num_workers", 0)
|
|
895
|
+
valid_pin_memory = getattr(self.config.DATA.valid.dataloader, "pin_memory", self.device.type == "cuda")
|
|
896
|
+
valid_mp_context = getattr(self.config.DATA.valid.dataloader, "mp_context", None)
|
|
897
|
+
valid_persistent_workers = getattr(self.config.DATA.valid.dataloader, "persistent_workers", valid_num_workers > 0)
|
|
898
|
+
valid_prefetch_factor = getattr(self.config.DATA.valid.dataloader, "prefetch_factor", 4 if valid_num_workers > 0 else None)
|
|
899
|
+
|
|
900
|
+
if valid_mp_context is not None:
|
|
901
|
+
valid_mp_context = mp.get_context(valid_mp_context)
|
|
902
|
+
|
|
886
903
|
val_loader = DataLoader(
|
|
887
904
|
val_dataset,
|
|
888
905
|
batch_size=self.config.DATA.valid.dataloader.batch_size,
|
|
889
906
|
shuffle=False,
|
|
890
907
|
sampler=val_sampler,
|
|
891
|
-
num_workers=
|
|
892
|
-
pin_memory=
|
|
908
|
+
num_workers=valid_num_workers,
|
|
909
|
+
pin_memory=valid_pin_memory,
|
|
893
910
|
collate_fn=collate_fn,
|
|
894
911
|
worker_init_fn=seed_worker,
|
|
895
912
|
generator=g,
|
|
896
|
-
|
|
897
|
-
|
|
913
|
+
multiprocessing_context=valid_mp_context,
|
|
914
|
+
persistent_workers=valid_persistent_workers,
|
|
915
|
+
prefetch_factor=valid_prefetch_factor,
|
|
898
916
|
)
|
|
899
917
|
|
|
900
918
|
# select the modality-specific trainer.
|
{opensportslib-0.1.3 → opensportslib-0.1.3.dev2}/opensportslib/datasets/classification_dataset.py
RENAMED
|
@@ -73,7 +73,7 @@ class ClassificationDataset(Dataset):
|
|
|
73
73
|
self.config = config
|
|
74
74
|
self.split = split
|
|
75
75
|
self.exclude_labels = ["Unknown", "Dont know"]
|
|
76
|
-
self.
|
|
76
|
+
self.video_path = getattr(config.DATA, split).video_path #config.DATA.data_dir
|
|
77
77
|
self.processor = None
|
|
78
78
|
|
|
79
79
|
# view_type is optional; only MVFoul uses it as of now
|
|
@@ -197,7 +197,7 @@ class VideoDataset(ClassificationDataset):
|
|
|
197
197
|
"""
|
|
198
198
|
|
|
199
199
|
def __init__(self, config, annotations_path, processor, split="train"):
|
|
200
|
-
super().__init__(config, annotations_path, split)
|
|
200
|
+
super().__init__(config, annotations_path, processor, split=split)
|
|
201
201
|
|
|
202
202
|
self.processor = processor
|
|
203
203
|
self.view_type = getattr(config.DATA, "view_type", "single")
|
|
@@ -227,15 +227,13 @@ class VideoDataset(ClassificationDataset):
|
|
|
227
227
|
"""read a video file, temporally sub-sample, and apply transforms.
|
|
228
228
|
|
|
229
229
|
Args:
|
|
230
|
-
path: realtive path (under
|
|
230
|
+
path: realtive path (under video_path) to the video file.
|
|
231
231
|
|
|
232
232
|
Returns:
|
|
233
233
|
numpy.ndarray of shape (T, H, W, C).
|
|
234
234
|
"""
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
if full_path.endswith(".npy"):
|
|
238
|
-
frames = np.load(full_path).astype(np.float32) / 255.0
|
|
235
|
+
if path.endswith(".npy"):
|
|
236
|
+
frames = np.load(os.path.join(self.video_path, path)).astype(np.float32) / 255.0
|
|
239
237
|
if self.transform is not None:
|
|
240
238
|
frames = self.transform(frames)
|
|
241
239
|
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
|
@@ -243,7 +241,7 @@ class VideoDataset(ClassificationDataset):
|
|
|
243
241
|
frames = (frames - mean) / std
|
|
244
242
|
return frames
|
|
245
243
|
|
|
246
|
-
v = read_video(os.path.join(self.
|
|
244
|
+
v = read_video(os.path.join(self.video_path, path))
|
|
247
245
|
|
|
248
246
|
v = process_frames(
|
|
249
247
|
v,
|
|
@@ -473,7 +471,7 @@ class TrackingDataset(ClassificationDataset):
|
|
|
473
471
|
"""read a single parquet tracking clip.
|
|
474
472
|
|
|
475
473
|
Args:
|
|
476
|
-
path: Relative path (under ``
|
|
474
|
+
path: Relative path (under ``video_path``) to the parquet
|
|
477
475
|
file.
|
|
478
476
|
|
|
479
477
|
Returns:
|
|
@@ -481,7 +479,7 @@ class TrackingDataset(ClassificationDataset):
|
|
|
481
479
|
"""
|
|
482
480
|
import pandas as pd
|
|
483
481
|
|
|
484
|
-
full_path = os.path.join(self.
|
|
482
|
+
full_path = os.path.join(self.video_path, path)
|
|
485
483
|
return pd.read_parquet(full_path)
|
|
486
484
|
|
|
487
485
|
def __getitem__(self, idx):
|