opensportslib 0.0.1.dev16__tar.gz → 0.0.1.dev18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.0.1.dev16/opensportslib.egg-info → opensportslib-0.0.1.dev18}/PKG-INFO +1 -1
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/apis/localization.py +17 -6
- opensportslib-0.0.1.dev18/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +145 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/trainer/localization_trainer.py +43 -20
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/load_annotations.py +31 -1
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/wandb.py +10 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/datasets/localization_dataset.py +33 -29
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/metrics/localization_metric.py +4 -4
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/contextaware.py +12 -12
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/learnablepooling.py +26 -26
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/builder.py +4 -3
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/impl/gsm.py +10 -3
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/utils.py +27 -9
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18/opensportslib.egg-info}/PKG-INFO +1 -1
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib.egg-info/SOURCES.txt +1 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/pyproject.toml +1 -1
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/LICENSE +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/MANIFEST.in +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/README.md +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/apis/classification.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/config/classification.yaml +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/config/localization.yaml +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/config/sngar-frames.yaml +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/config/sngar-tracking.yaml +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/loss/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/optimizer/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/sampler/weighted_sampler.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/scheduler/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/trainer/classification_trainer.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/config.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/default_args.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/video_processing.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/datasets/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/datasets/classification_dataset.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/datasets/utils/tracking.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/backbones/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/tracking.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/vars.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/video.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/base/video_mae.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.0.1.
|
|
3
|
+
Version: 0.0.1.dev18
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -101,7 +101,7 @@ class LocalizationAPI:
|
|
|
101
101
|
|
|
102
102
|
device = select_device(self.config.SYSTEM)
|
|
103
103
|
self.model = build_model(self.config, device=device)
|
|
104
|
-
print(self.model)
|
|
104
|
+
print(f"model: {self.model}")
|
|
105
105
|
|
|
106
106
|
|
|
107
107
|
# Datasets
|
|
@@ -155,7 +155,7 @@ class LocalizationAPI:
|
|
|
155
155
|
from opensportslib.core.trainer.localization_trainer import build_inferer, build_evaluator
|
|
156
156
|
from opensportslib.core.utils.config import select_device, resolve_config_omega, is_local_path
|
|
157
157
|
from opensportslib.core.utils.checkpoint import load_checkpoint, localization_remap
|
|
158
|
-
from opensportslib.core.utils.load_annotations import check_config, has_localization_events
|
|
158
|
+
from opensportslib.core.utils.load_annotations import check_config, has_localization_events, whether_infer_split
|
|
159
159
|
from opensportslib.core.utils.wandb import init_wandb
|
|
160
160
|
import time
|
|
161
161
|
|
|
@@ -163,6 +163,7 @@ class LocalizationAPI:
|
|
|
163
163
|
self.config.MODEL.multi_gpu = False
|
|
164
164
|
self.config = resolve_config_omega(self.config)
|
|
165
165
|
check_config(self.config, split="test")
|
|
166
|
+
self.config.infer_split = whether_infer_split(self.config.DATA.test)
|
|
166
167
|
init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
167
168
|
logging.info("Configuration:")
|
|
168
169
|
logging.info(self.config)
|
|
@@ -179,19 +180,29 @@ class LocalizationAPI:
|
|
|
179
180
|
logging.info("No predictions provided, running inference.")
|
|
180
181
|
device = select_device(self.config.SYSTEM)
|
|
181
182
|
self.model = build_model(self.config, device=device)
|
|
183
|
+
inner_model = getattr(self.model, "_model", None)
|
|
184
|
+
if inner_model is None:
|
|
185
|
+
inner_model = getattr(self.model, "model", self.model)
|
|
182
186
|
print("Model type:", type(self.model))
|
|
183
|
-
print("Torch model type:", type(
|
|
187
|
+
print("Torch model type:", type(inner_model))
|
|
184
188
|
# Load model
|
|
185
189
|
if pretrained:
|
|
186
190
|
#pretrained = expand(pretrained)
|
|
187
191
|
if is_local_path(pretrained):
|
|
188
192
|
self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(pretrained))
|
|
189
193
|
|
|
190
|
-
|
|
194
|
+
inner_model, _, _, epoch = load_checkpoint(model=inner_model,
|
|
191
195
|
path=pretrained,
|
|
192
196
|
device=device,
|
|
193
197
|
key_remap_fn=localization_remap)
|
|
194
198
|
|
|
199
|
+
if hasattr(self.model, "_model"):
|
|
200
|
+
self.model._model = inner_model
|
|
201
|
+
elif hasattr(self.model, "model"):
|
|
202
|
+
self.model.model = inner_model
|
|
203
|
+
else:
|
|
204
|
+
self.model = inner_model
|
|
205
|
+
|
|
195
206
|
# Datasets
|
|
196
207
|
# Test
|
|
197
208
|
data_obj_test = build_dataset(self.config, split="test")
|
|
@@ -206,7 +217,7 @@ class LocalizationAPI:
|
|
|
206
217
|
# # Inference
|
|
207
218
|
inferer = build_inferer(cfg=self.config.MODEL,
|
|
208
219
|
model=self.model)
|
|
209
|
-
json_gz_file = inferer.infer(cfg=self.config, data=dataset_Test)
|
|
220
|
+
json_gz_file = inferer.infer(cfg=self.config, data=dataset_Test, dataloader=test_loader)
|
|
210
221
|
|
|
211
222
|
#json_gz_file = self.config.DATA.test.results + ".recall.json.gz"
|
|
212
223
|
json_gz_file = predictions if predictions else json_gz_file
|
|
@@ -219,7 +230,7 @@ class LocalizationAPI:
|
|
|
219
230
|
evaluator = build_evaluator(cfg=self.config)
|
|
220
231
|
metrics = evaluator.evaluate(
|
|
221
232
|
cfg_testset=self.config.DATA.test,
|
|
222
|
-
json_gz_file=json_gz_file
|
|
233
|
+
json_gz_file=self.config.DATA.test.results if isinstance(json_gz_file, dict) else json_gz_file
|
|
223
234
|
)
|
|
224
235
|
else:
|
|
225
236
|
logging.info("No labels found in annotation file → skipping evaluation")
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
TASK: localization
|
|
2
|
+
|
|
3
|
+
dali: false
|
|
4
|
+
|
|
5
|
+
DATA:
|
|
6
|
+
dataset_name: SoccerNet
|
|
7
|
+
data_dir: /home/vorajv/opensportslib/SoccerNet/
|
|
8
|
+
classes:
|
|
9
|
+
- Penalty
|
|
10
|
+
- Kick-off
|
|
11
|
+
- Goal
|
|
12
|
+
- Substitution
|
|
13
|
+
- Offside
|
|
14
|
+
- Shots on target
|
|
15
|
+
- Shots off target
|
|
16
|
+
- Clearance
|
|
17
|
+
- Ball out of play
|
|
18
|
+
- Throw-in
|
|
19
|
+
- Foul
|
|
20
|
+
- Indirect free-kick
|
|
21
|
+
- Direct free-kick
|
|
22
|
+
- Corner
|
|
23
|
+
- Yellow card
|
|
24
|
+
- Red card
|
|
25
|
+
- Yellow->red card
|
|
26
|
+
|
|
27
|
+
epoch_num_frames: 500000
|
|
28
|
+
mixup: true
|
|
29
|
+
modality: rgb
|
|
30
|
+
crop_dim: -1
|
|
31
|
+
dilate_len: 0 # Dilate ground truth labels
|
|
32
|
+
clip_len: 100
|
|
33
|
+
input_fps: 25
|
|
34
|
+
extract_fps: 2
|
|
35
|
+
imagenet_mean: [0.485, 0.456, 0.406]
|
|
36
|
+
imagenet_std: [0.229, 0.224, 0.225]
|
|
37
|
+
target_height: 224
|
|
38
|
+
target_width: 398
|
|
39
|
+
|
|
40
|
+
train:
|
|
41
|
+
type: FeatureClipsfromJSON
|
|
42
|
+
classes: ${DATA.classes}
|
|
43
|
+
output_map: [data, label]
|
|
44
|
+
video_path: ${DATA.data_dir}
|
|
45
|
+
path: ${DATA.train.video_path}/annotations-2024-224p-train.json
|
|
46
|
+
framerate: 2
|
|
47
|
+
window_size: 20
|
|
48
|
+
dataloader:
|
|
49
|
+
batch_size: 256
|
|
50
|
+
shuffle: true
|
|
51
|
+
num_workers: 4
|
|
52
|
+
pin_memory: true
|
|
53
|
+
|
|
54
|
+
valid:
|
|
55
|
+
type: FeatureClipsfromJSON
|
|
56
|
+
classes: ${DATA.classes}
|
|
57
|
+
output_map: [data, label]
|
|
58
|
+
video_path: ${DATA.data_dir}
|
|
59
|
+
path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
|
|
60
|
+
framerate: 2
|
|
61
|
+
window_size: 20
|
|
62
|
+
dataloader:
|
|
63
|
+
batch_size: 256
|
|
64
|
+
shuffle: true
|
|
65
|
+
num_workers: 4
|
|
66
|
+
pin_memory: true
|
|
67
|
+
|
|
68
|
+
test:
|
|
69
|
+
type: FeatureVideosfromJSON
|
|
70
|
+
classes: ${DATA.classes}
|
|
71
|
+
output_map: [data, label]
|
|
72
|
+
video_path: ${DATA.data_dir}
|
|
73
|
+
path: ${DATA.test.video_path}/annotations-2024-224p-test.json
|
|
74
|
+
results: results_spotting_test_netvlad++_resnetpca512
|
|
75
|
+
framerate: 2
|
|
76
|
+
window_size: 20
|
|
77
|
+
metric: tight
|
|
78
|
+
dataloader:
|
|
79
|
+
batch_size: 1
|
|
80
|
+
shuffle: false
|
|
81
|
+
num_workers: 1
|
|
82
|
+
pin_memory: true
|
|
83
|
+
|
|
84
|
+
MODEL:
|
|
85
|
+
type: LearnablePooling
|
|
86
|
+
runner:
|
|
87
|
+
type: runner_JSON
|
|
88
|
+
backbone:
|
|
89
|
+
type: PreExtactedFeatures
|
|
90
|
+
encoder: ResNET_TF2_PCA512
|
|
91
|
+
feature_dim: 512
|
|
92
|
+
output_dim: 512
|
|
93
|
+
framerate: 2
|
|
94
|
+
window_size: 20
|
|
95
|
+
neck:
|
|
96
|
+
type: NetVLAD++
|
|
97
|
+
input_dim: 512
|
|
98
|
+
output_dim: 32768 # 512 clusters * 64 vocab size
|
|
99
|
+
vocab_size: 64
|
|
100
|
+
head:
|
|
101
|
+
type: LinearLayer
|
|
102
|
+
input_dim: 32768
|
|
103
|
+
num_classes: 17
|
|
104
|
+
post_proc:
|
|
105
|
+
type: NMS
|
|
106
|
+
NMS_window: 30
|
|
107
|
+
NMS_threshold: 0.0
|
|
108
|
+
load_weights: null
|
|
109
|
+
|
|
110
|
+
TRAIN:
|
|
111
|
+
type: trainer_pooling
|
|
112
|
+
max_epochs: 200
|
|
113
|
+
evaluation_frequency: 200
|
|
114
|
+
framerate: 2
|
|
115
|
+
batch_size: 256
|
|
116
|
+
|
|
117
|
+
criterion:
|
|
118
|
+
type: NLLLoss
|
|
119
|
+
|
|
120
|
+
optimizer:
|
|
121
|
+
type: Adam
|
|
122
|
+
lr: 0.001
|
|
123
|
+
betas: [0.9, 0.999]
|
|
124
|
+
eps: 1e-08
|
|
125
|
+
weight_decay: 0
|
|
126
|
+
amsgrad: false
|
|
127
|
+
|
|
128
|
+
scheduler:
|
|
129
|
+
type: ReduceLROnPlateau
|
|
130
|
+
mode: min
|
|
131
|
+
factor: 1e-03
|
|
132
|
+
min_lr: 1e-06
|
|
133
|
+
patience: 10
|
|
134
|
+
verbose: true
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
SYSTEM:
|
|
138
|
+
log_dir: ./logs
|
|
139
|
+
save_dir: ./checkpoints
|
|
140
|
+
work_dir: ${SYSTEM.save_dir}
|
|
141
|
+
seed: 42
|
|
142
|
+
GPU: 4 # number of gpus to use
|
|
143
|
+
device: cuda # auto | cuda | cpu
|
|
144
|
+
gpu_id: 0 # device id for single gpu training
|
|
145
|
+
|
|
@@ -160,15 +160,23 @@ class Trainer_pl(Trainer):
|
|
|
160
160
|
def __init__(self, cfg, work_dir):
|
|
161
161
|
from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
|
|
162
162
|
import pytorch_lightning as pl
|
|
163
|
+
from pytorch_lightning.loggers import WandbLogger
|
|
164
|
+
import wandb
|
|
165
|
+
|
|
166
|
+
wandb_logger = None
|
|
167
|
+
if wandb.run is not None: # means init_wandb already ran
|
|
168
|
+
wandb_logger = WandbLogger(experiment=wandb.run)
|
|
163
169
|
|
|
164
170
|
self.work_dir = work_dir
|
|
165
171
|
call = MyCallback()
|
|
166
172
|
self.trainer = pl.Trainer(
|
|
167
|
-
|
|
168
|
-
|
|
173
|
+
logger=wandb_logger,
|
|
174
|
+
max_epochs=cfg.TRAIN.max_epochs,
|
|
175
|
+
devices=cfg.SYSTEM.GPU,
|
|
169
176
|
callbacks=[call, CustomProgressBar(refresh_rate=1)],
|
|
170
177
|
num_sanity_val_steps=0,
|
|
171
178
|
)
|
|
179
|
+
self.best_checkpoint_path = None
|
|
172
180
|
|
|
173
181
|
def train(self, **kwargs):
|
|
174
182
|
self.trainer.fit(**kwargs)
|
|
@@ -177,10 +185,12 @@ class Trainer_pl(Trainer):
|
|
|
177
185
|
|
|
178
186
|
logging.info("Done training")
|
|
179
187
|
logging.info("Best epoch: {}".format(best_model.get("epoch")))
|
|
180
|
-
|
|
188
|
+
best_path = os.path.join(self.work_dir, "model.pth.tar")
|
|
189
|
+
self.best_checkpoint_path = best_path
|
|
190
|
+
torch.save(best_model, best_path)
|
|
181
191
|
|
|
182
192
|
logging.info("Model saved")
|
|
183
|
-
logging.info(
|
|
193
|
+
logging.info(best_path)
|
|
184
194
|
|
|
185
195
|
|
|
186
196
|
class Trainer_e2e(Trainer):
|
|
@@ -496,24 +506,25 @@ class Inferer:
|
|
|
496
506
|
self.model = model
|
|
497
507
|
self.infer_Spotting=infer_Spotting
|
|
498
508
|
|
|
499
|
-
def infer(self, cfg, data):
|
|
509
|
+
def infer(self, cfg, data, dataloader=None):
|
|
500
510
|
"""Infer actions from data.
|
|
501
511
|
|
|
502
512
|
Args:
|
|
503
513
|
data : The data from which we will infer.
|
|
514
|
+
dataloader : The dataloader for the test data.
|
|
504
515
|
|
|
505
516
|
Returns:
|
|
506
517
|
Dict containing predictions
|
|
507
518
|
"""
|
|
508
519
|
if self.infer_Spotting=="infer_JSON":
|
|
509
|
-
return self.infer_JSON(cfg, self.model, data)
|
|
520
|
+
return self.infer_JSON(cfg, self.model, data, dataloader)
|
|
510
521
|
elif self.infer_Spotting=="infer_SN":
|
|
511
|
-
return self.infer_SN(cfg, self.model, data)
|
|
522
|
+
return self.infer_SN(cfg, self.model, data, dataloader)
|
|
512
523
|
elif self.infer_Spotting=="infer_E2E":
|
|
513
|
-
return self.infer_E2E(cfg, self.model, data)
|
|
524
|
+
return self.infer_E2E(cfg, self.model, data, dataloader)
|
|
514
525
|
|
|
515
526
|
|
|
516
|
-
def infer_common(self, cfg, model, data):
|
|
527
|
+
def infer_common(self, cfg, model, data, dataloader=None):
|
|
517
528
|
"""Infer actions from data using a given model.
|
|
518
529
|
|
|
519
530
|
Args:
|
|
@@ -525,10 +536,21 @@ class Inferer:
|
|
|
525
536
|
Dict containing predictions
|
|
526
537
|
"""
|
|
527
538
|
# Run Inference on Dataset
|
|
528
|
-
|
|
539
|
+
from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
|
|
540
|
+
import pytorch_lightning as pl
|
|
541
|
+
|
|
542
|
+
if cfg.SYSTEM.work_dir is not None and dataloader is not None:
|
|
543
|
+
|
|
544
|
+
evaluator = pl.Trainer(
|
|
545
|
+
callbacks=[CustomProgressBar()],
|
|
546
|
+
devices=cfg.SYSTEM.GPU,
|
|
547
|
+
num_sanity_val_steps=0,
|
|
548
|
+
)
|
|
549
|
+
evaluator.predict(model, dataloader)
|
|
550
|
+
return model.json_data
|
|
529
551
|
|
|
530
552
|
|
|
531
|
-
def infer_JSON(self, cfg, model, data):
|
|
553
|
+
def infer_JSON(self, cfg, model, data, dataloader=None):
|
|
532
554
|
"""Infer actions from data using a given model for NetVlad/CALF methods
|
|
533
555
|
|
|
534
556
|
Args:
|
|
@@ -539,10 +561,10 @@ class Inferer:
|
|
|
539
561
|
Returns:
|
|
540
562
|
Dict containing predictions
|
|
541
563
|
"""
|
|
542
|
-
return self.infer_common(cfg, model, data)
|
|
564
|
+
return self.infer_common(cfg, model, data, dataloader)
|
|
543
565
|
|
|
544
566
|
|
|
545
|
-
def infer_SN(self, cfg, model, data):
|
|
567
|
+
def infer_SN(self, cfg, model, data, dataloader=None):
|
|
546
568
|
"""Infer actions from data using a given model for the SNV2 data
|
|
547
569
|
|
|
548
570
|
Args:
|
|
@@ -553,10 +575,10 @@ class Inferer:
|
|
|
553
575
|
Returns:
|
|
554
576
|
Dict containing predictions
|
|
555
577
|
"""
|
|
556
|
-
return self.infer_common(cfg, model, data)
|
|
578
|
+
return self.infer_common(cfg, model, data, dataloader)
|
|
557
579
|
|
|
558
580
|
|
|
559
|
-
def infer_E2E(self, cfg, model, data):
|
|
581
|
+
def infer_E2E(self, cfg, model, data, dataloader=None):
|
|
560
582
|
"""Infer actions from data using a given model for the e2espot method.
|
|
561
583
|
|
|
562
584
|
Args:
|
|
@@ -735,7 +757,6 @@ class Evaluator:
|
|
|
735
757
|
|
|
736
758
|
|
|
737
759
|
def evaluate_common_JSON(self, cfg, results, metric):
|
|
738
|
-
|
|
739
760
|
if cfg.path is None:
|
|
740
761
|
return
|
|
741
762
|
|
|
@@ -756,6 +777,7 @@ class Evaluator:
|
|
|
756
777
|
|
|
757
778
|
# detect v2 prediction
|
|
758
779
|
pred_is_v2 = isinstance(pred_data, dict) and pred_data is not None and "data" in pred_data
|
|
780
|
+
print("PRED V2 :", pred_is_v2)
|
|
759
781
|
# --------------------------------------------------
|
|
760
782
|
# CLASSES
|
|
761
783
|
# --------------------------------------------------
|
|
@@ -803,7 +825,7 @@ class Evaluator:
|
|
|
803
825
|
video_path = game["inputs"][0]["path"]
|
|
804
826
|
labels = [{"label": e.get("label"),
|
|
805
827
|
"gameTime": e.get("gameTime"),
|
|
806
|
-
"position": int(e.get("position_ms")),
|
|
828
|
+
"position": int(e.get("position_ms", e.get("position"))),
|
|
807
829
|
} for e in game.get("events", [])]
|
|
808
830
|
else:
|
|
809
831
|
video_path = game["path"]
|
|
@@ -825,7 +847,7 @@ class Evaluator:
|
|
|
825
847
|
"label": e.get("label"),
|
|
826
848
|
"gameTime": e.get("gameTime"),
|
|
827
849
|
"confidence": e.get("confidence"),
|
|
828
|
-
"position": int(e.get("position_ms")),
|
|
850
|
+
"position": int(e.get("position_ms", e.get("position"))),
|
|
829
851
|
"frame": e.get("frame")
|
|
830
852
|
}
|
|
831
853
|
for e in item.get("events", [])
|
|
@@ -859,7 +881,7 @@ class Evaluator:
|
|
|
859
881
|
"label": e.get("label"),
|
|
860
882
|
"gameTime": e.get("gameTime"),
|
|
861
883
|
"confidence": e.get("confidence"),
|
|
862
|
-
"position": int(e.get("position_ms")),
|
|
884
|
+
"position": int(e.get("position_ms", e.get("position"))),
|
|
863
885
|
"frame": e.get("frame")
|
|
864
886
|
}
|
|
865
887
|
for e in item.get("events", [])
|
|
@@ -997,7 +1019,8 @@ class Evaluator:
|
|
|
997
1019
|
Returns
|
|
998
1020
|
The different mAPs computed.
|
|
999
1021
|
"""
|
|
1000
|
-
|
|
1022
|
+
from SoccerNet.Evaluation.utils import INVERSE_EVENT_DICTIONARY_V2
|
|
1023
|
+
from SoccerNet.Evaluation.ActionSpotting import evaluate
|
|
1001
1024
|
# challenge sets to be tested on EvalAI
|
|
1002
1025
|
if "challenge" in cfg.split:
|
|
1003
1026
|
print("Visit eval.ai to evaluate performances on Challenge set")
|
{opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/core/utils/load_annotations.py
RENAMED
|
@@ -536,4 +536,34 @@ def check_config(cfg, split="train"):
|
|
|
536
536
|
classes = cfg.DATA.classes
|
|
537
537
|
|
|
538
538
|
#print(classes)
|
|
539
|
-
cfg.DATA.classes = load_classes(classes)
|
|
539
|
+
cfg.DATA.classes = load_classes(classes)
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def whether_infer_split(cfg):
|
|
543
|
+
"""Given a config dict, check whether we want to infer a split or a single element (can be a game, video or feature file)/
|
|
544
|
+
|
|
545
|
+
Args:
|
|
546
|
+
cfg (dict): Config dict.
|
|
547
|
+
|
|
548
|
+
Returns:
|
|
549
|
+
bool : True if we infer split, false otherwise. Raises an error if the input is not expected.
|
|
550
|
+
"""
|
|
551
|
+
if cfg.type == "SoccerNetGames" or cfg.type == "SoccerNetClipsTestingCALF":
|
|
552
|
+
if cfg.split == None:
|
|
553
|
+
return False
|
|
554
|
+
else:
|
|
555
|
+
return True
|
|
556
|
+
elif (
|
|
557
|
+
cfg.type == "FeatureVideosfromJSON" or cfg.type == "FeatureVideosChunksfromJson"
|
|
558
|
+
):
|
|
559
|
+
if cfg.path.endswith(".json"):
|
|
560
|
+
return True
|
|
561
|
+
else:
|
|
562
|
+
return False
|
|
563
|
+
elif cfg.type == "VideoGameWithOpencvVideo" or cfg.type == "VideoGameWithDaliVideo":
|
|
564
|
+
if cfg.path.endswith(".json"):
|
|
565
|
+
return True
|
|
566
|
+
else:
|
|
567
|
+
return False
|
|
568
|
+
else:
|
|
569
|
+
raise ValueError(f"Unknown dataset type {cfg.type}")
|
|
@@ -2,6 +2,7 @@ import wandb
|
|
|
2
2
|
import matplotlib.pyplot as plt
|
|
3
3
|
import numpy as np
|
|
4
4
|
import logging
|
|
5
|
+
import os
|
|
5
6
|
|
|
6
7
|
def init_wandb(cfg, run_id, use_wandb=False):
|
|
7
8
|
"""
|
|
@@ -24,6 +25,15 @@ def init_wandb(cfg, run_id, use_wandb=False):
|
|
|
24
25
|
logging.warning("wandb not installed. Install with `pip install wandb`.")
|
|
25
26
|
return None
|
|
26
27
|
|
|
28
|
+
# Prevent multiple processes from initializing wandb
|
|
29
|
+
rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
|
|
30
|
+
if rank != 0:
|
|
31
|
+
return None
|
|
32
|
+
|
|
33
|
+
# Prevent re-initialization
|
|
34
|
+
if wandb.run is not None:
|
|
35
|
+
return wandb
|
|
36
|
+
|
|
27
37
|
if getattr(cfg.DATA, "data_modality", None):
|
|
28
38
|
run_name = f"{cfg.MODEL.backbone.type}_{cfg.DATA.data_modality}"
|
|
29
39
|
else:
|
|
@@ -29,7 +29,6 @@ try:
|
|
|
29
29
|
DALI_AVAILABLE = True
|
|
30
30
|
|
|
31
31
|
except ImportError:
|
|
32
|
-
print("NO DALI")
|
|
33
32
|
DALI_AVAILABLE = False
|
|
34
33
|
# Optional: placeholders (prevents NameError)
|
|
35
34
|
pipeline_def = None
|
|
@@ -67,7 +66,6 @@ class LocalizationDataset(Dataset):
|
|
|
67
66
|
|
|
68
67
|
|
|
69
68
|
def building_dataset(self, cfg, gpu=None, default_args=None):
|
|
70
|
-
print(cfg)
|
|
71
69
|
if cfg.type == "SoccerNetClips" or cfg.type == "SoccerNetGames":
|
|
72
70
|
if cfg.split == None:
|
|
73
71
|
dataset = SoccerNetGameClips(
|
|
@@ -1576,7 +1574,8 @@ class FeaturefromJson(Dataset):
|
|
|
1576
1574
|
with open(single_path) as f:
|
|
1577
1575
|
tmp = json.load(f)
|
|
1578
1576
|
self.data_json.append(tmp)
|
|
1579
|
-
|
|
1577
|
+
for task_name, task_data in tmp["labels"].items():
|
|
1578
|
+
self.classes.append(task_data.get("labels", {}))
|
|
1580
1579
|
assert all(x == self.classes[0] for x in self.classes) == True
|
|
1581
1580
|
self.classes = self.classes[0]
|
|
1582
1581
|
|
|
@@ -1588,15 +1587,19 @@ class FeaturefromJson(Dataset):
|
|
|
1588
1587
|
with open(path) as f:
|
|
1589
1588
|
tmp = json.load(f)
|
|
1590
1589
|
self.data_json = [tmp]
|
|
1591
|
-
|
|
1590
|
+
|
|
1591
|
+
for task_name, task_data in tmp["labels"].items():
|
|
1592
|
+
self.classes = task_data.get("labels", {})
|
|
1592
1593
|
else:
|
|
1593
1594
|
self.is_json = False
|
|
1594
1595
|
self.data_json = [
|
|
1595
1596
|
{
|
|
1596
|
-
"
|
|
1597
|
+
"data": [
|
|
1597
1598
|
{
|
|
1598
|
-
"
|
|
1599
|
-
|
|
1599
|
+
"inputs": [{
|
|
1600
|
+
"path": path
|
|
1601
|
+
}],
|
|
1602
|
+
"events": [],
|
|
1600
1603
|
}
|
|
1601
1604
|
]
|
|
1602
1605
|
}
|
|
@@ -1608,6 +1611,7 @@ class FeaturefromJson(Dataset):
|
|
|
1608
1611
|
self.classes = load_text(self.classes)
|
|
1609
1612
|
|
|
1610
1613
|
self.num_classes = len(self.classes)
|
|
1614
|
+
print(self.num_classes)
|
|
1611
1615
|
self.event_dictionary = {cls: i_cls for i_cls, cls in enumerate(self.classes)}
|
|
1612
1616
|
self.inverse_event_dictionary = {
|
|
1613
1617
|
i_cls: cls for i_cls, cls in enumerate(self.classes)
|
|
@@ -1633,8 +1637,8 @@ class FeaturefromJson(Dataset):
|
|
|
1633
1637
|
# time = annotation["gameTime"]
|
|
1634
1638
|
event = annotation["label"]
|
|
1635
1639
|
|
|
1636
|
-
if "
|
|
1637
|
-
frame = int(self.framerate * (int(annotation["
|
|
1640
|
+
if "position_ms" in annotation.keys():
|
|
1641
|
+
frame = int(self.framerate * (int(annotation["position_ms"]) / 1000))
|
|
1638
1642
|
else:
|
|
1639
1643
|
time = annotation["gameTime"]
|
|
1640
1644
|
|
|
@@ -1685,11 +1689,11 @@ class FeatureClipsfromJSON(FeaturefromJson):
|
|
|
1685
1689
|
else:
|
|
1686
1690
|
logging.info("Processing " + path)
|
|
1687
1691
|
# loop over videos
|
|
1688
|
-
for video in tqdm.tqdm(single_data_json["
|
|
1692
|
+
for video in tqdm.tqdm(single_data_json["data"]):
|
|
1689
1693
|
# for video in tqdm(self.data_json["videos"]):
|
|
1690
1694
|
# Load features
|
|
1691
1695
|
features = np.load(
|
|
1692
|
-
os.path.join(self.features_dir[i], video["path"])
|
|
1696
|
+
os.path.join(self.features_dir[i], video["inputs"][0]["path"])
|
|
1693
1697
|
)
|
|
1694
1698
|
features = features.reshape(-1, features.shape[-1])
|
|
1695
1699
|
|
|
@@ -1705,7 +1709,7 @@ class FeatureClipsfromJSON(FeaturefromJson):
|
|
|
1705
1709
|
labels[:, 0] = 1 # those are BG classes
|
|
1706
1710
|
|
|
1707
1711
|
# loop annotation for that video
|
|
1708
|
-
for annotation in video
|
|
1712
|
+
for annotation in video.get("events", []):
|
|
1709
1713
|
|
|
1710
1714
|
label, frame, cont = self.annotation(annotation)
|
|
1711
1715
|
|
|
@@ -1743,20 +1747,20 @@ class FeatureClipsfromJSON(FeaturefromJson):
|
|
|
1743
1747
|
if self.train:
|
|
1744
1748
|
return self.features_clips[index, :, :], self.labels_clips[index, :]
|
|
1745
1749
|
else:
|
|
1746
|
-
video = self.data_json[0]["
|
|
1747
|
-
|
|
1750
|
+
video = self.data_json[0]["data"][index]
|
|
1751
|
+
video_path = video["inputs"][0]["path"]
|
|
1748
1752
|
# Load features
|
|
1749
1753
|
if self.is_json:
|
|
1750
|
-
features = np.load(os.path.join(self.features_dir[0],
|
|
1754
|
+
features = np.load(os.path.join(self.features_dir[0], video_path))
|
|
1751
1755
|
else:
|
|
1752
|
-
features = np.load(os.path.join(
|
|
1756
|
+
features = np.load(os.path.join(video_path))
|
|
1753
1757
|
features = features.reshape(-1, features.shape[-1])
|
|
1754
1758
|
|
|
1755
1759
|
# Load labels
|
|
1756
1760
|
labels = np.zeros((features.shape[0], self.num_classes))
|
|
1757
1761
|
|
|
1758
|
-
if "
|
|
1759
|
-
for annotation in video
|
|
1762
|
+
if "events" in video.keys():
|
|
1763
|
+
for annotation in video.get("events", []):
|
|
1760
1764
|
|
|
1761
1765
|
label, frame, cont = self.annotation(annotation)
|
|
1762
1766
|
|
|
@@ -1773,13 +1777,13 @@ class FeatureClipsfromJSON(FeaturefromJson):
|
|
|
1773
1777
|
clip_length=self.window_size_frame,
|
|
1774
1778
|
)
|
|
1775
1779
|
|
|
1776
|
-
return
|
|
1780
|
+
return video_path, features, labels
|
|
1777
1781
|
|
|
1778
1782
|
def __len__(self):
|
|
1779
1783
|
if self.train:
|
|
1780
1784
|
return len(self.features_clips)
|
|
1781
1785
|
else:
|
|
1782
|
-
return len(self.data_json[0]["
|
|
1786
|
+
return len(self.data_json[0]["data"])
|
|
1783
1787
|
|
|
1784
1788
|
|
|
1785
1789
|
class FeatureClipChunksfromJson(FeaturefromJson):
|
|
@@ -1844,18 +1848,18 @@ class FeatureClipChunksfromJson(FeaturefromJson):
|
|
|
1844
1848
|
else:
|
|
1845
1849
|
logging.info("Processing " + path)
|
|
1846
1850
|
# loop over videos
|
|
1847
|
-
for video in tqdm.tqdm(single_data_json["
|
|
1851
|
+
for video in tqdm.tqdm(single_data_json["data"]):
|
|
1848
1852
|
# for video in tqdm(self.data_json["videos"]):
|
|
1849
1853
|
# Load features
|
|
1850
1854
|
features = np.load(
|
|
1851
|
-
os.path.join(self.features_dir[i], video["path"])
|
|
1855
|
+
os.path.join(self.features_dir[i], video["inputs"][0]["path"])
|
|
1852
1856
|
)
|
|
1853
1857
|
|
|
1854
1858
|
# Load labels
|
|
1855
1859
|
labels = np.zeros((features.shape[0], self.num_classes))
|
|
1856
1860
|
|
|
1857
1861
|
# loop annotation for that video
|
|
1858
|
-
for annotation in video
|
|
1862
|
+
for annotation in video.get("events", []):
|
|
1859
1863
|
|
|
1860
1864
|
label, frame, cont = self.annotation(annotation)
|
|
1861
1865
|
|
|
@@ -1938,19 +1942,19 @@ class FeatureClipChunksfromJson(FeaturefromJson):
|
|
|
1938
1942
|
torch.from_numpy(clip_targets),
|
|
1939
1943
|
)
|
|
1940
1944
|
else:
|
|
1941
|
-
video = self.data_json[0]["
|
|
1942
|
-
|
|
1945
|
+
video = self.data_json[0]["data"][index]
|
|
1946
|
+
video_path = video["inputs"][0]["path"]
|
|
1943
1947
|
# Load features
|
|
1944
1948
|
if self.is_json:
|
|
1945
|
-
features = np.load(os.path.join(self.features_dir[0],
|
|
1949
|
+
features = np.load(os.path.join(self.features_dir[0], video_path))
|
|
1946
1950
|
else:
|
|
1947
|
-
features = np.load(os.path.join(
|
|
1951
|
+
features = np.load(os.path.join(video_path))
|
|
1948
1952
|
|
|
1949
1953
|
# Load labels
|
|
1950
1954
|
labels = np.zeros((features.shape[0], self.num_classes))
|
|
1951
1955
|
|
|
1952
1956
|
if "annotations" in video.keys():
|
|
1953
|
-
for annotation in video
|
|
1957
|
+
for annotation in video.get("events",[]):
|
|
1954
1958
|
|
|
1955
1959
|
label, frame, cont = self.annotation(annotation)
|
|
1956
1960
|
if cont:
|
|
@@ -1977,7 +1981,7 @@ class FeatureClipChunksfromJson(FeaturefromJson):
|
|
|
1977
1981
|
if self.train:
|
|
1978
1982
|
return self.chunks_per_epoch
|
|
1979
1983
|
else:
|
|
1980
|
-
return len(self.data_json[0]["
|
|
1984
|
+
return len(self.data_json[0]["data"])
|
|
1981
1985
|
|
|
1982
1986
|
|
|
1983
1987
|
|
{opensportslib-0.0.1.dev16 → opensportslib-0.0.1.dev18}/opensportslib/metrics/localization_metric.py
RENAMED
|
@@ -880,9 +880,9 @@ def label2vector(
|
|
|
880
880
|
|
|
881
881
|
else:
|
|
882
882
|
time = annotation["gameTime"]
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
883
|
+
if time is not None:
|
|
884
|
+
minutes = int(time[-5:-3])
|
|
885
|
+
seconds = int(time[-2::])
|
|
886
886
|
# annotation at millisecond precision
|
|
887
887
|
if "position" in annotation:
|
|
888
888
|
frame = int(framerate * (int(annotation["position"]) / 1000))
|
|
@@ -931,7 +931,7 @@ def predictions2vector(
|
|
|
931
931
|
|
|
932
932
|
event = annotation["label"]
|
|
933
933
|
|
|
934
|
-
if "frame" in annotation:
|
|
934
|
+
if "frame" in annotation and annotation["frame"] is not None:
|
|
935
935
|
frame = int(annotation["frame"])
|
|
936
936
|
else:
|
|
937
937
|
time = int(annotation["position"])
|