opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,135 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+ import math
31
+ import torch
32
+ import torch.nn as nn
33
+ import torchvision
34
+ import timm
35
+
36
+ from .impl.tsm import TemporalShift
37
+ from .impl.gsm import _GSM
38
+
39
+
40
class GatedShift(nn.Module):
    """Gated Shift Module (GSM) wrapper around a 2D layer.

    The first ``fold_dim`` input channels are routed through ``_GSM``; the
    remaining channels are copied through unchanged; the result is then fed
    to the wrapped layer ``net``.

    Args:
        net: Layer to wrap. Must be one of the types handled by
            ``_input_channels`` so its input channel count can be read.
        n_segment: Frames per clip, forwarded to ``_GSM``.
        n_div: ``channels // n_div`` channels (rounded up to a multiple of 4)
            are gated.

    Raises:
        NotImplementedError: If ``net`` is of an unsupported type.
    """

    def __init__(self, net, n_segment, n_div):
        super().__init__()
        channels = self._input_channels(net)

        # Size of the gated channel slice, rounded up to a multiple of 4.
        self.fold_dim = math.ceil(channels // n_div / 4) * 4
        # Assignment order (gsm, then net) is kept as-is so submodule
        # registration order does not change.
        self.gsm = _GSM(self.fold_dim, n_segment)
        self.net = net
        self.n_segment = n_segment
        print(f"=> Using GSM, fold dim: {self.fold_dim} / {channels}")

    @staticmethod
    def _input_channels(net):
        """Return the input channel count of a supported layer type."""
        if isinstance(net, torchvision.models.resnet.BasicBlock):
            return net.conv1.in_channels
        if isinstance(net, torchvision.ops.misc.ConvNormActivation):
            return net[0].in_channels
        if isinstance(net, timm.layers.conv_bn_act.ConvBnAct):
            return net.conv.in_channels
        if isinstance(net, nn.Conv2d):
            return net.in_channels
        raise NotImplementedError(type(net))

    def forward(self, x):
        """Gate the leading ``fold_dim`` channels of ``x``, then apply ``net``.

        The 4-index slicing implies ``x`` is a 4-D tensor with channels on
        dim 1 (presumably (N*T, C, H, W) — confirm against callers).
        """
        fold = self.fold_dim
        mixed = torch.zeros_like(x)
        mixed[:, :fold, :, :] = self.gsm(x[:, :fold, :, :])
        mixed[:, fold:, :, :] = x[:, fold:, :, :]
        return self.net(mixed)
66
+
67
+
68
+ # Adapted from: https://github.com/mit-han-lab/temporal-shift-module/blob/master/ops/temporal_shift.py
69
def make_temporal_shift(net, clip_len, is_gsm=False):
    """Insert temporal-shift wrappers into a 2D CNN backbone, in place.

    Supported backbones:
      * torchvision ``ResNet``: wraps ``conv1`` of the selected blocks in
        ``layer1``..``layer4``; each layer is re-assigned as a new
        ``nn.Sequential`` holding the same blocks.
      * timm ``RegNet``: wraps ``conv1`` of every block in ``s1``..``s4``.
      * timm ``ConvNeXt``: wraps ``conv_dw`` of every block in
        ``stages[0]``..``stages[3]``.

    Args:
        net: Backbone to modify.
        clip_len: Frames per clip, passed as ``n_segment`` to the wrapper.
        is_gsm: If True use ``GatedShift`` (n_div=4), otherwise
            ``TemporalShift`` (n_div=8).

    Raises:
        NotImplementedError: If ``net`` is not a supported backbone.
    """

    def wrap(layer):
        if is_gsm:
            return GatedShift(layer, n_segment=clip_len, n_div=4)
        return TemporalShift(layer, n_segment=clip_len, n_div=8)

    def shift_stage(blocks, attr, every):
        # Replace `attr` on every `every`-th block (starting at index 0).
        print("=> Processing stage with {} blocks residual".format(len(blocks)))
        for idx in range(0, len(blocks), every):
            setattr(blocks[idx], attr, wrap(getattr(blocks[idx], attr)))
        return blocks

    if isinstance(net, torchvision.models.ResNet):
        # Deep variants (>= 23 blocks in layer3) only shift every other block.
        every = 2 if len(list(net.layer3.children())) >= 23 else 1
        print("=> Using n_round {} to insert temporal shift".format(every))
        for name in ("layer1", "layer2", "layer3", "layer4"):
            blocks = list(getattr(net, name).children())
            setattr(net, name, nn.Sequential(*shift_stage(blocks, "conv1", every)))
    elif isinstance(net, timm.models.regnet.RegNet):
        for name in ("s1", "s2", "s3", "s4"):
            shift_stage(list(getattr(net, name).children()), "conv1", 1)
    elif isinstance(net, timm.models.convnext.ConvNeXt):
        for k in range(4):
            shift_stage(list(net.stages[k].blocks), "conv_dw", 1)
    else:
        raise NotImplementedError("Unsupported architecture")
@@ -0,0 +1,276 @@
1
+ import json
2
+ import os
3
+ import logging
4
+ import zipfile
5
+ import numpy as np
6
+
7
+ from SoccerNet.Evaluation.utils import (
8
+ INVERSE_EVENT_DICTIONARY_V2,
9
+ INVERSE_EVENT_DICTIONARY_V1,
10
+ )
11
+
12
+
13
def check_if_should_predict(folder_name, work_dir, overwrite):
    """Decide whether inference should run, given an existing results zip.

    Args:
        folder_name (str): Name of the output folder; the zip is
            ``<work_dir>/<folder_name>.zip``.
        work_dir (str): Directory where the zip is located.
        overwrite (bool): If True, existing results never block inference.

    Returns:
        tuple: ``(folder_name, zip_path, stop_predict)`` where ``stop_predict``
        is True only when the zip already exists and ``overwrite`` is False.
    """
    zip_path = os.path.join(work_dir, f"{folder_name}.zip")
    if os.path.exists(zip_path) and not overwrite:
        logging.warning(
            "Results already exists in zip format. Use [overwrite=True] to overwrite the previous results.The inference will not run over the previous results."
        )
        return folder_name, zip_path, True
    return folder_name, zip_path, False
33
+
34
+
35
def timestamp(model, feat, BS, device="cuda"):
    """Run ``model`` over ``feat`` in chunks of ``BS`` rows and stack outputs.

    Args:
        model: Callable applied to each chunk; its output must support
            ``.detach().cpu().numpy()`` (e.g. a torch module).
        feat: Sliceable tensor-like supporting ``.to(device)``; chunks are
            taken along the first dimension.
        BS (int): Chunk (batch) size.
        device: Device each chunk is moved to before calling ``model``.
            Defaults to ``"cuda"``, matching the previous hard-coded
            ``.cuda()`` behaviour; pass ``"cpu"`` for CPU-only inference.

    Returns:
        np.ndarray: Per-chunk outputs concatenated along axis 0.

    Raises:
        ValueError: If ``feat`` is empty (nothing to concatenate).
    """
    outputs = []
    for start in range(0, len(feat), BS):
        chunk = feat[start:start + BS].to(device)
        outputs.append(model(chunk).detach().cpu().numpy())
    return np.concatenate(outputs)
45
+
46
+
47
def get_spot_from_NMS(Input, window=60, thresh=0.0):
    """Greedy 1-D non-maximum suppression over a score sequence.

    Repeatedly picks the highest remaining score (first index on ties) while
    it is ``>= thresh``, records it, then suppresses the neighbourhood
    ``[peak - window/2, peak + int(window/2))`` (clipped to the sequence).
    The peak itself is always suppressed, so the loop terminates for any
    ``window`` (including ``< 2``) and any ``thresh`` (including ``<= -1``,
    where the previous ``-1`` sentinel never stopped the loop).

    Args:
        Input: 1-D array-like of scores.
        window (int): Suppression window width, in indices.
        thresh (float): Minimum score for a peak to be kept.

    Returns:
        np.ndarray: ``(K, 2)`` array of ``[index, score]`` rows in selection
        order; ``(0, 2)`` when no peak qualifies or ``Input`` is empty.
    """
    scores = np.asarray(Input)
    n_items = len(scores)
    remaining = np.ones(n_items, dtype=bool)
    half_window = window / 2
    indexes = []
    MaxValues = []

    while remaining.any():
        candidates = np.flatnonzero(remaining)
        max_index = candidates[np.argmax(scores[candidates])]
        max_value = scores[max_index]
        # Written as `not >=` so a NaN peak stops the search, as before.
        if not max_value >= thresh:
            break
        MaxValues.append(max_value)
        indexes.append(max_index)

        nms_from = int(np.maximum(-half_window + max_index, 0))
        nms_to = int(np.minimum(max_index + int(half_window), n_items))
        remaining[nms_from:nms_to] = False
        remaining[max_index] = False

    return np.transpose([indexes, MaxValues])
65
+
66
+
67
def get_json_data(info):
    """Build the skeleton of a predictions JSON document.

    Args:
        info: Identifier stored under ``"Url"`` (name of the feature/video
            file or game).

    Returns:
        dict: ``{"Url": info, "predictions": []}`` with a fresh list that
        callers append prediction dicts to.
    """
    return {"Url": info, "predictions": []}
77
+
78
+
79
def get_prediction_data(
    calf,
    frame_index,
    framerate,
    class_index=None,
    confidence=None,
    half=None,
    l=None,
    version=None,
    half_1=None,
    runner="runner_JSON",
    inverse_event_dictionary=None,
):
    """Build the dict describing one predicted event.

    The dict holds the game time, label, position (milliseconds) and
    confidence; for SoccerNet runners it also holds the half.

    Args:
        calf (bool): If True the class comes from ``class_index`` and the half
            from ``half_1``; otherwise from ``l`` and ``half``.
        frame_index (int): Frame of the event.
        framerate (int): Frames per second of ``frame_index``.
        class_index (int): Class index, used when ``calf`` is True.
        confidence (float): Confidence of the prediction. NumPy scalars are
            converted to Python scalars so the result is JSON serialisable.
        half (int): Zero-based half, used when ``calf`` is False.
        l (int): Class index, used when ``calf`` is False.
        version (int): SoccerNet label version; 2 selects
            ``INVERSE_EVENT_DICTIONARY_V2``, anything else
            ``INVERSE_EVENT_DICTIONARY_V1``.
        half_1 (bool): True for the first half, used when ``calf`` is True.
        runner (str): ``"runner_JSON"`` omits the half and uses
            ``inverse_event_dictionary``; any other value (e.g.
            ``"runner_pooling"``, ``"runner_CALF"``) uses the SoccerNet
            dictionaries and includes the half.
        inverse_event_dictionary (dict): Index -> class name mapping; required
            for ``"runner_JSON"``.

    Returns:
        dict: Keys ``gameTime``, ``label``, ``position``, ``confidence`` (plus
        a leading ``half`` for non-JSON runners).
    """
    elapsed_seconds = frame_index // framerate
    seconds = int(elapsed_seconds % 60)
    minutes = int(elapsed_seconds // 60)
    # The class index source depends on the method, for every label lookup.
    label_index = class_index if calf else l

    prediction_data = dict()
    if runner == "runner_JSON":
        prediction_data["gameTime"] = f"{minutes:02.0f}:{seconds:02.0f}"
        prediction_data["label"] = inverse_event_dictionary[label_index]
    else:
        prediction_data["half"] = str(1 if half_1 else 2) if calf else str(half + 1)
        # NOTE(review): the calf gameTime is not zero-padded, unlike the other
        # branch; kept as-is so existing output is unchanged.
        prediction_data["gameTime"] = (
            (str(1 if half_1 else 2) + " - " + str(minutes) + ":" + str(seconds))
            if calf
            else f"{half+1} - {minutes:02.0f}:{seconds:02.0f}"
        )
        # Previously the V1 lookup always used `l`, which is None for calf
        # callers and raised KeyError.
        dictionary = (
            INVERSE_EVENT_DICTIONARY_V2 if version == 2 else INVERSE_EVENT_DICTIONARY_V1
        )
        prediction_data["label"] = dictionary[label_index]

    prediction_data["position"] = int((frame_index / framerate) * 1000)
    # np.float32 and similar are not JSON serialisable; unwrap numpy scalars.
    if isinstance(confidence, np.generic):
        confidence = confidence.item()
    prediction_data["confidence"] = confidence

    return prediction_data
140
+
141
+
142
def zipResults(zip_path, target_dir, filename="results_spotting.json"):
    """Zip every file named ``filename`` found under ``target_dir``.

    Archive member names are relative to ``target_dir`` (e.g.
    ``game1/results_spotting.json``), whether or not ``target_dir`` ends with
    a path separator. The archive is closed (and thus finalised) on return.

    Args:
        zip_path (str): Path of the zip file to create (overwritten if present).
        target_dir (str): Root directory searched recursively.
        filename (str): Exact file name to include.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipobj:
        for base, _dirs, files in os.walk(target_dir):
            if filename in files:
                file_path = os.path.join(base, filename)
                zipobj.write(file_path, os.path.relpath(file_path, target_dir))
151
+
152
+
153
def NMS(detections, delta):
    """Per-class greedy non-maximum suppression.

    For each column (class) of ``detections``, repeatedly keeps the highest
    remaining non-negative score (first row on ties) and suppresses rows
    ``[peak - delta/2, peak + int(delta/2))``. The peak row is always
    suppressed, so the loop terminates even when ``delta < 2``.

    Args:
        detections (np.ndarray): 2-D array of scores, rows indexed by frame
            and columns by class; negative values count as "no detection".
        delta (int): Suppression window width, in rows.

    Returns:
        np.ndarray: Float array with the same shape as ``detections``, holding
        the kept scores and ``-1`` everywhere else.
    """
    # Working copy whose suppressed entries are overwritten with -1.
    detections_tmp = np.copy(detections)
    detections_NMS = np.zeros(detections.shape) - 1
    n_frames = detections.shape[0]
    if n_frames == 0:
        return detections_NMS

    half_delta = delta / 2
    for i in range(detections.shape[-1]):
        column = detections_tmp[:, i]  # view: writes go into detections_tmp
        while np.max(column) >= 0:
            max_index = np.argmax(column)
            detections_NMS[max_index, i] = column[max_index]

            nms_from = int(np.maximum(-half_delta + max_index, 0))
            nms_to = int(np.minimum(max_index + int(half_delta), n_frames))
            column[nms_from:nms_to] = -1
            column[max_index] = -1

    return detections_NMS
178
+
179
+
180
def predictions2json(
    predictions_half_1, predictions_half_2, json_data, output_path, framerate=2
):
    """Append every non-negative prediction of both halves to ``json_data``
    and write it to ``output_path``. Used for runner_CALF.

    Args:
        predictions_half_1: 2-D array (frame x class) of scores for the first
            half; entries ``>= 0`` become predictions.
        predictions_half_2: Same as above, for the second half.
        json_data: Dict with a ``"predictions"`` list that is extended in
            place.
        output_path: Path of the JSON file to write.
        framerate (int): Frames per second of the prediction arrays.
            Default: 2.

    Returns:
        The same ``json_data`` object, after it has been written.
    """
    # Locate the kept entries of both halves before appending anything.
    per_half = [
        (scores, np.where(scores >= 0), is_first)
        for scores, is_first in (
            (predictions_half_1, True),
            (predictions_half_2, False),
        )
    ]

    for scores, (frames, classes), is_first in per_half:
        for frame_index, class_index in zip(frames, classes):
            json_data["predictions"].append(
                get_prediction_data(
                    True,
                    frame_index,
                    framerate,
                    class_index=class_index,
                    confidence=scores[frame_index, class_index],
                    version=2,
                    half_1=is_first,
                    runner="runner_CALF",
                )
            )

    with open(output_path, "w") as output_file:
        json.dump(json_data, output_file, indent=4)
    return json_data
235
+
236
+
237
def predictions2json_runnerjson(
    predictions_video,
    json_data,
    output_path,
    framerate=2,
    inverse_event_dictionary=None,
):
    """Append every non-negative prediction of a video to ``json_data`` and
    write it to ``output_path``. Used for runner_JSON.

    Args:
        predictions_video: 2-D array (frame x class) of scores for the whole
            video; entries ``>= 0`` become predictions.
        json_data: Dict with a ``"predictions"`` list that is extended in
            place.
        output_path: Path of the JSON file to write.
        framerate (int): Frames per second of the prediction array.
            Default: 2.
        inverse_event_dictionary (dict): Mapping from class index to class
            name.

    Returns:
        The same ``json_data`` object, after it has been written.
    """
    kept_frames, kept_classes = np.where(predictions_video >= 0)

    for frame_index, class_index in zip(kept_frames, kept_classes):
        entry = get_prediction_data(
            True,
            frame_index,
            framerate,
            class_index=class_index,
            confidence=predictions_video[frame_index, class_index],
            version=2,
            half_1=True,
            runner="runner_JSON",
            inverse_event_dictionary=inverse_event_dictionary,
        )
        json_data["predictions"].append(entry)

    with open(output_path, "w") as output_file:
        json.dump(json_data, output_file, indent=4)
    return json_data