opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
|
|
3
|
+
Kayvon Fatahalian
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without modification,
|
|
6
|
+
are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation and/or
|
|
13
|
+
other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its contributors
|
|
16
|
+
may be used to endorse or promote products derived from this software without
|
|
17
|
+
specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
20
|
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
21
|
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
|
23
|
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
24
|
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
25
|
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
|
26
|
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
27
|
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
28
|
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
29
|
+
"""
|
|
30
|
+
import torch
|
|
31
|
+
import torch.nn.functional as F
|
|
32
|
+
import torch.nn as nn
|
|
33
|
+
|
|
34
|
+
from opensportslib.models.utils.modules import *
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def build_head(cfg, default_args=None):
    """Instantiate the prediction head selected by ``cfg.type``.

    Args:
        cfg: Config object read with attribute access. It must expose
            ``type`` plus whichever attributes the selected head consumes.
        default_args (dict | None, optional): Extra construction arguments.
            Only the "TrackingClassifier" branch reads it (key
            ``"input_dim"``). Default: None.

    Returns:
        head: The constructed head, or None when ``cfg.type`` matches none
        of the supported head types.
    """
    head_type = cfg.type

    if head_type == "TrackingClassifier":
        return TrackingClassifierHead(
            input_dim=default_args["input_dim"],
            hidden_dim=cfg.hidden_dim,
            num_classes=cfg.num_classes,
            dropout=cfg.dropout,
        )

    if head_type == "LinearLayer":
        # One extra output unit on top of the configured classes.
        return LinearLayerHead(
            input_dim=cfg.input_dim,
            output_dim=cfg.num_classes + 1,
        )

    if head_type == "SpottingCALF":
        return SpottingCALFHead(
            num_classes=cfg.num_classes,
            dim_capsule=cfg.dim_capsule,
            num_detections=cfg.num_detections,
            chunk_size=cfg.chunk_size,
        )

    # The empty string selects the plain per-frame head of TemporalE2EHead.
    if head_type in ("", "gru", "deeper_gru", "mstcn", "asformer"):
        return TemporalE2EHead(head_type, cfg.feat_dim, cfg.num_classes)

    if head_type == "MV_LinearLayer":
        return MVHead(input_dim=cfg.feat_dim, output_dim=cfg.num_classes)

    # Unknown head types are not an error here: the caller receives None.
    return None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class MVHead(nn.Module):
    """Head made of a LayerNorm followed by two stacked linear layers.

    There is no activation between the two linear layers. The layers live
    in ``self.fc`` at indices 0 (norm), 1 and 2 (linears).

    Args:
        input_dim: Size of the last dimension of the incoming features.
        output_dim: Number of values produced per feature vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        print("Inside MVHead")
        stages = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, input_dim),
            nn.Linear(input_dim, output_dim),
        ]
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the normalisation and both projections to ``x``."""
        return self.fc(x)
|
|
87
|
+
|
|
88
|
+
class TrackingClassifierHead(nn.Module):
    """Classification head for tracking data.

    The classifier is dropout -> linear -> ReLU -> dropout -> linear, where
    the second dropout uses half the probability of the first.

    Args:
        input_dim: Size of the incoming feature vectors.
        hidden_dim: Width of the intermediate linear layer.
        num_classes: Number of output logits.
        dropout: Probability of the first dropout layer. Default: 0.1.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, dropout=0.1):
        super().__init__()

        layers = [
            nn.Dropout(dropout),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for ``x``."""
        logits = self.classifier(x)
        return logits
|
|
104
|
+
|
|
105
|
+
class TemporalE2EHead(nn.Module):
    """Temporal prediction head wrapping one of several sequence predictors.

    The predictor is chosen by ``temporal_arch`` and stored in
    ``self._pred_fine``:

    * ``"gru"`` / ``"deeper_gru"`` -> ``GRUPrediction`` with 1 / 3 layers
    * ``"mstcn"``                  -> ``TCNPrediction``
    * ``"asformer"``               -> ``ASFormerPrediction``
    * ``""``                       -> ``FCPrediction``

    Args:
        temporal_arch (str): Name of the temporal architecture (see above).
        feat_dim (int): Feature dimension handed to the predictor.
        num_classes (int): Number of classes handed to the predictor.

    Raises:
        NotImplementedError: If ``temporal_arch`` is not one of the supported
            names. This is raised at construction time so that a bad
            configuration does not surface later as a missing ``_pred_fine``
            attribute inside ``forward``.
    """

    def __init__(self, temporal_arch, feat_dim, num_classes):
        super().__init__()
        # Prevent the GRU params from going too big (cap it at a RegNet-Y 800MF)
        MAX_GRU_HIDDEN_DIM = 768
        if "gru" in temporal_arch:
            hidden_dim = feat_dim
            if hidden_dim > MAX_GRU_HIDDEN_DIM:
                hidden_dim = MAX_GRU_HIDDEN_DIM
                print("Clamped GRU hidden dim: {} -> {}".format(feat_dim, hidden_dim))
            if temporal_arch in ("gru", "deeper_gru"):
                self._pred_fine = GRUPrediction(
                    feat_dim,
                    num_classes,
                    hidden_dim,
                    # "deeper_gru" is the only accepted name starting with "d".
                    num_layers=3 if temporal_arch[0] == "d" else 1,
                )
            else:
                raise NotImplementedError(temporal_arch)
        elif temporal_arch == "mstcn":
            self._pred_fine = TCNPrediction(feat_dim, num_classes, 3)
        elif temporal_arch == "asformer":
            self._pred_fine = ASFormerPrediction(feat_dim, num_classes, 3)
        elif temporal_arch == "":
            self._pred_fine = FCPrediction(feat_dim, num_classes)
        else:
            # Without this branch the module would be built without a
            # predictor and only fail on the first forward pass.
            raise NotImplementedError(temporal_arch)

    def forward(self, inputs):
        """Run the selected predictor on ``inputs`` and return its output."""
        return self._pred_fine(inputs)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class LinearLayerHead(torch.nn.Module):
    """Dropout, a single linear projection, then an element-wise sigmoid.

    Args:
        input_dim: Size of the incoming feature vectors.
        output_dim: Number of sigmoid outputs per feature vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim

        self.drop = torch.nn.Dropout(p=0.4)
        self.head = torch.nn.Linear(input_dim, output_dim)
        self.sigm = torch.nn.Sigmoid()

    def forward(self, inputs):
        """Return sigmoid scores in [0, 1] for ``inputs``."""
        dropped = self.drop(inputs)
        projected = self.head(dropped)
        return self.sigm(projected)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class SpottingCALFHead(torch.nn.Module):
    """Spotting (detection) head producing confidences and class scores.

    The head concatenates the inverted segmentation scores to the capsule
    features, reduces the temporal axis with three max-pool stages (two of
    them preceded by a zero-padded convolution), flattens the result and
    feeds it to two 1x1 convolutions: a confidence branch and a class branch.

    Args:
        num_classes: Number of classes.
        dim_capsule: Capsule dimension; the first convolution expects
            ``num_classes * (dim_capsule + 1)`` input channels.
        num_detections: Number of detections predicted per chunk.
        chunk_size: Temporal length of a chunk; the branch convolutions
            expect ``16 * (chunk_size // 8 - 1)`` flattened features.
    """

    def __init__(self, num_classes, dim_capsule, num_detections, chunk_size):
        super().__init__()

        self.num_classes = num_classes
        self.dim_capsule = dim_capsule
        self.num_detections = num_detections
        self.chunk_size = chunk_size

        # -------------------
        # detection module
        # -------------------
        self.kernel_spot_size = 3
        # Padding is applied on the height axis only, split so that the
        # padded convolutions keep that axis length unchanged.
        pad_before = (self.kernel_spot_size - 1) // 2
        pad_after = self.kernel_spot_size - 1 - pad_before
        height_padding = (0, 0, pad_before, pad_after)
        height_kernel = (self.kernel_spot_size, 1)

        self.max_pool_spot = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1))
        self.pad_spot_1 = nn.ZeroPad2d(height_padding)
        self.conv_spot_1 = nn.Conv2d(
            in_channels=num_classes * (dim_capsule + 1),
            out_channels=32,
            kernel_size=height_kernel,
        )
        self.max_pool_spot_1 = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1))
        self.pad_spot_2 = nn.ZeroPad2d(height_padding)
        self.conv_spot_2 = nn.Conv2d(
            in_channels=32,
            out_channels=16,
            kernel_size=height_kernel,
        )
        self.max_pool_spot_2 = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1))

        # Number of features left once the pooled map is flattened.
        flattened_features = 16 * (chunk_size // 8 - 1)

        # Confidence branch
        self.conv_conf = nn.Conv2d(
            in_channels=flattened_features,
            out_channels=self.num_detections * 2,
            kernel_size=(1, 1),
        )

        # Class branch
        self.conv_class = nn.Conv2d(
            in_channels=flattened_features,
            out_channels=self.num_detections * self.num_classes,
            kernel_size=(1, 1),
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, conv_seg, output_segmentation):
        """Compute the spotting output.

        Args:
            conv_seg: Capsule features, concatenated on the channel axis
                (dim 1) with the inverted segmentation scores.
            output_segmentation: 3-D segmentation scores; ``1 - scores`` is
                moved so that its last axis becomes the channel axis.
                (assumes layout (batch, time, classes) — TODO confirm)

        Returns:
            Tensor of shape ``(batch, num_detections, 2 + num_classes)``:
            two sigmoid values followed by softmax class scores.
        """
        # ---------------
        # Spotting module
        # ---------------

        # Concatenation of the segmentation score to the capsules
        inverted_scores = (1 - output_segmentation).unsqueeze(2).permute(0, 3, 1, 2)
        features = torch.cat((conv_seg, inverted_scores), dim=1)

        features = self.max_pool_spot(F.relu(features))

        features = F.relu(self.conv_spot_1(self.pad_spot_1(features)))
        features = self.max_pool_spot_1(features)

        features = F.relu(self.conv_spot_2(self.pad_spot_2(features)))
        features = self.max_pool_spot_2(features)

        batch = features.size()[0]
        flattened = features.view(batch, -1, 1, 1)

        # Confidence branch
        conf_logits = self.conv_conf(flattened).view(batch, self.num_detections, 2)
        conf_pred = torch.sigmoid(conf_logits)

        # Class branch
        class_logits = self.conv_class(flattened).view(
            batch, self.num_detections, self.num_classes
        )
        conf_class = self.softmax(class_logits)

        return torch.cat((conf_pred, conf_class), dim=-1)
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from opensportslib.core.utils.data import batch_tensor, unbatch_tensor
|
|
4
|
+
|
|
5
|
+
def build_neck(cfg, default_args=None):
    """Instantiate the neck selected by ``cfg.type``.

    Args:
        cfg: Config object read with attribute access; must expose ``type``
            and the attributes used by the selected neck.
        default_args (dict | None, optional): Runtime construction arguments.
            "MV_Aggregate" reads ``"model"``, ``"feat_dim"`` and, optionally,
            ``"lifting_net"``; "TemporalAggregation" reads ``"window_size"``.

    Returns:
        The constructed neck module.

    Raises:
        ValueError: If ``cfg.type`` is not a known neck type.
    """
    neck_type = cfg.type

    if neck_type == "MV_Aggregate":
        kwargs = {
            "agr_type": cfg.agr_type,
            "model": default_args["model"],
            "feat_dim": default_args["feat_dim"],
        }
        # Fall back to an empty (identity) Sequential when no lifting
        # network was supplied.
        if "lifting_net" in default_args:
            kwargs["lifting_net"] = default_args["lifting_net"]
        else:
            kwargs["lifting_net"] = nn.Sequential()
        return MVAggregate(**kwargs)

    if neck_type == "TemporalAggregation":
        return TemporalAggregation(
            temporal_type=cfg.agr_type,
            hidden_dim=cfg.hidden_dim,
            window_size=default_args["window_size"],
            dropout=cfg.dropout,
            # Optional settings: use the fallbacks below when cfg lacks them.
            use_position_encoding=getattr(cfg, "use_position_encoding", False),
            num_attention_heads=getattr(cfg, "num_attention_heads", 4),
            lstm_dropout=getattr(cfg, "lstm_dropout", 0.1),
        )

    raise ValueError(f"Unknown neck type: {cfg.type}")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MVAggregate(nn.Module):
    """Neck that pools multi-view features and refines the pooled vector.

    ``agr_type`` selects the aggregation module: ``"max"`` uses
    ``ViewMaxAggregate``, ``"mean"`` uses ``ViewAvgAggregate`` and any other
    value uses ``WeightedAggregate``. The pooled features then go through a
    LayerNorm and two linear layers (``self.inter``).

    Args:
        agr_type: Aggregation type (see above).
        model: Backbone handed to the aggregation module.
        feat_dim: Feature dimension of the pooled vector.
        lifting_net: Module handed to the aggregation module; defaults to an
            empty ``nn.Sequential``.
    """

    def __init__(self, agr_type, model, feat_dim, lifting_net=nn.Sequential()):
        super().__init__()
        self.agr_type = agr_type
        self.model = model
        self.feat_dim = feat_dim
        self.lifting_net = lifting_net
        print("Inside NECK BUILDER - AGR TYPE:", self.agr_type)

        if agr_type == "max":
            aggregator = ViewMaxAggregate(model=model, lifting_net=lifting_net)
        elif agr_type == "mean":
            aggregator = ViewAvgAggregate(model=model, lifting_net=lifting_net)
        else:
            # Every other value falls back to the learned weighted aggregation.
            aggregator = WeightedAggregate(
                model=model, feat_dim=feat_dim, lifting_net=lifting_net
            )
        self.aggregation_model = aggregator

        refinement = [
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, feat_dim),
            nn.Linear(feat_dim, feat_dim),
        ]
        self.inter = nn.Sequential(*refinement)

    def forward(self, mvimages):
        """Return ``(refined_features, attention)`` for ``mvimages``.

        ``attention`` is the second value returned by the aggregation module,
        passed through unchanged.
        """
        pooled_view, attention = self.aggregation_model(mvimages)
        return self.inter(pooled_view), attention
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class WeightedAggregate(nn.Module):
    """Fuse per-view features with weights derived from view similarity.

    Per-view features are projected by a learnable square matrix, a
    ReLU-rectified view-by-view similarity matrix is built from the projected
    features, normalised per sample, and reduced to one weight per view. The
    output is the weighted sum of the projected features over the view axis.

    Args:
        model: Backbone applied to the views after they have been folded
            into the batch axis by ``batch_tensor``.
        feat_dim: Feature dimension; size of the square projection matrix.
        lifting_net: Module applied to the un-batched backbone output;
            defaults to an empty ``nn.Sequential``.
    """

    def __init__(self, model, feat_dim, lifting_net=nn.Sequential()):
        super().__init__()
        self.model = model
        self.lifting_net = lifting_net
        self.feature_dim = feat_dim

        # Learnable (feat_dim, feat_dim) projection initialised as
        # (r1 - r2) * U[0, 1) + r2, i.e. uniformly in (-1, 1].
        r1 = -1
        r2 = 1
        self.attention_weights = nn.Parameter((r1 - r2) * torch.rand(feat_dim, feat_dim) + r2)

        # Not used by forward(); kept so the registered LayerNorm parameters
        # (and therefore the state_dict keys) stay unchanged.
        self.normReLu = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.ReLU()
        )

        self.relu = nn.ReLU()

    def forward(self, mvimages):
        """Aggregate the views of ``mvimages``.

        Args:
            mvimages: 6-D tensor unpacked as (Batch, Views, Channel, Depth,
                Height, Width).

        Returns:
            Tuple ``(output, final_attention_weights)``: the weighted sum of
            the projected per-view features with all size-1 dimensions
            squeezed out, and the per-view weights of shape ``(B, V)``.
        """
        B, V, C, D, H, W = mvimages.shape  # Batch, Views, Channel, Depth, Height, Width
        # presumably batch_tensor folds the view axis into the batch axis and
        # unbatch_tensor restores it — verify against core.utils.data
        aux = self.lifting_net(unbatch_tensor(self.model(batch_tensor(mvimages, dim=1, squeeze=True)), B, dim=1, unsqueeze=True))

        ##################### VIEW ATTENTION #####################

        # Project the features; aux must be 3-D (B, V, E) from here on
        # because it is fed to torch.bmm below.
        aux = torch.matmul(aux, self.attention_weights)

        # Rectified view-by-view similarity, shape (B, V, V).
        relu_res = self.relu(torch.bmm(aux, aux.permute(0, 2, 1)))

        # Normalise each sample's similarity matrix by its total mass.
        # NOTE(review): a sample whose similarities are all zero yields
        # 0 / 0 = NaN here, as it did before this change.
        flat = torch.reshape(relu_res, (B, V * V))
        totals = torch.sum(flat, dim=1, keepdim=True)
        final_attention_weights = torch.reshape(flat / totals, (B, V, V))

        # One weight per view: sum the normalised matrix over its first view axis.
        final_attention_weights = torch.sum(final_attention_weights, 1)

        # Weight the (B, V, E) features directly. Squeezing aux first would
        # drop a size-1 view axis and make the broadcast mix features across
        # batch samples when V == 1 and B > 1.
        output = torch.mul(aux, final_attention_weights.unsqueeze(-1))

        output = torch.sum(output, 1)

        # The final squeeze is kept for backward compatibility with callers.
        return output.squeeze(), final_attention_weights
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ViewMaxAggregate(nn.Module):
    """Pool per-view features with an element-wise maximum over the views.

    Args:
        model: Backbone applied to the views after ``batch_tensor``.
        lifting_net: Module applied to the un-batched backbone output;
            defaults to an empty ``nn.Sequential``.
    """

    def __init__(self, model, lifting_net=nn.Sequential()):
        super().__init__()
        self.model = model
        self.lifting_net = lifting_net

    def forward(self, mvimages):
        """Return ``(pooled, per_view)`` for a 6-D ``mvimages`` tensor.

        ``pooled`` is the maximum over dim 1 of the per-view features with
        all size-1 dimensions squeezed out; ``per_view`` is the un-pooled
        feature tensor.
        """
        # Batch, Views, Channel, Depth, Height, Width
        batch_size, _views, _channels, _depth, _height, _width = mvimages.shape

        stacked = batch_tensor(mvimages, dim=1, squeeze=True)
        per_view = unbatch_tensor(self.model(stacked), batch_size, dim=1, unsqueeze=True)
        aux = self.lifting_net(per_view)

        pooled_view, _indices = torch.max(aux, dim=1)
        return pooled_view.squeeze(), aux
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class ViewAvgAggregate(nn.Module):
    """Pool per-view features with a mean over the views.

    Args:
        model: Backbone applied to the views after ``batch_tensor``.
        lifting_net: Module applied to the un-batched backbone output;
            defaults to an empty ``nn.Sequential``.
    """

    def __init__(self, model, lifting_net=nn.Sequential()):
        super().__init__()
        self.model = model
        self.lifting_net = lifting_net

    def forward(self, mvimages):
        """Return ``(pooled, per_view)`` for a 6-D ``mvimages`` tensor.

        ``pooled`` is the mean over dim 1 of the per-view features with all
        size-1 dimensions squeezed out; ``per_view`` is the un-pooled
        feature tensor.
        """
        # Batch, Views, Channel, Depth, Height, Width
        batch_size, _views, _channels, _depth, _height, _width = mvimages.shape

        stacked = batch_tensor(mvimages, dim=1, squeeze=True)
        per_view = unbatch_tensor(self.model(stacked), batch_size, dim=1, unsqueeze=True)
        aux = self.lifting_net(per_view)

        return torch.mean(aux, dim=1).squeeze(), aux
|
|
139
|
+
|
|
140
|
+
class TemporalAggregation(nn.Module):
    """Collapse the sequence axis (dim 1) of a feature tensor.

    Supported ``temporal_type`` values:

    * ``"avgpool"`` / ``"maxpool"`` - plain mean / max over dim 1
    * ``"tcn"``       - two Conv1d layers, then max over dim 1
    * ``"attention"`` - multi-head self-attention, then max over dim 1
    * ``"bilstm"``    - 2-layer bidirectional LSTM, then max over dim 1

    Any other value builds no temporal module and ``forward`` returns its
    input without pooling.

    Args:
        temporal_type: Aggregation strategy (see above).
        hidden_dim: Feature size of the incoming sequence elements.
        window_size: Length of the learnable position encoding.
        dropout: Dropout used by the attention module. Default: 0.1.
        use_position_encoding: Add a learnable position encoding to the
            input before aggregating. Default: False.
        num_attention_heads: Heads of the attention module. Default: 4.
        lstm_dropout: Dropout between the LSTM layers. Default: 0.3.

    Attributes:
        feat_dim: ``2 * hidden_dim`` for ``"bilstm"``, else ``hidden_dim``.
    """

    def __init__(self, temporal_type, hidden_dim, window_size, dropout=0.1,
                 use_position_encoding=False, num_attention_heads=4, lstm_dropout=0.3):
        super().__init__()

        self.temporal_type = temporal_type
        self.hidden_dim = hidden_dim
        self.num_attention_heads = num_attention_heads
        self.lstm_dropout = lstm_dropout
        self.use_position_encoding = use_position_encoding

        # The bidirectional LSTM concatenates both directions.
        if temporal_type == "bilstm":
            self.feat_dim = hidden_dim * 2
        else:
            self.feat_dim = hidden_dim

        # learnable temporal position encoding (only used when explicitly enabled)
        if use_position_encoding:
            initial_encoding = torch.randn(1, window_size, hidden_dim) * 0.02
            self.temporal_position_encoding = nn.Parameter(initial_encoding)

        # build temporal module
        self.temporal = self._build_temporal_module(temporal_type, hidden_dim, dropout)

    def _build_temporal_module(self, temporal_type, hidden_dim, dropout):
        """Return the trainable module for ``temporal_type`` or None."""
        if temporal_type == 'bilstm':
            return nn.LSTM(
                hidden_dim,
                hidden_dim,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
                dropout=self.lstm_dropout,
            )

        if temporal_type == 'tcn':
            first_conv = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
            second_conv = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
            return nn.Sequential(first_conv, nn.ReLU(), second_conv)

        if temporal_type == 'attention':
            return nn.MultiheadAttention(
                hidden_dim,
                num_heads=self.num_attention_heads,
                dropout=dropout,
                batch_first=True,
            )

        # avgpool, maxpool (and anything unrecognised) need no parameters
        return None

    def forward(self, x):
        """Aggregate ``x`` over dim 1 according to ``self.temporal_type``."""
        seq_len = x.size(1)

        if self.use_position_encoding:
            # Slice the encoding so shorter sequences than window_size work.
            x = x + self.temporal_position_encoding[:, :seq_len, :]

        mode = self.temporal_type

        if mode == 'avgpool':
            return torch.mean(x, dim=1)
        if mode == 'maxpool':
            return torch.max(x, dim=1)[0]

        if mode == 'tcn':
            # Conv1d expects channels on dim 1, so swap before and after.
            x = self.temporal(x.permute(0, 2, 1)).permute(0, 2, 1)
        elif mode == 'attention':
            x, _ = self.temporal(x, x, x)
        elif mode == 'bilstm':
            x, _ = self.temporal(x)
        else:
            # Unrecognised type: hand the input back untouched.
            return x

        return torch.max(x, dim=1)[0]
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
|
|
3
|
+
Kayvon Fatahalian
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without modification,
|
|
6
|
+
are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation and/or
|
|
13
|
+
other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its contributors
|
|
16
|
+
may be used to endorse or promote products derived from this software without
|
|
17
|
+
specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
20
|
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
21
|
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
|
23
|
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
24
|
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
25
|
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
|
26
|
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
27
|
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
28
|
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
29
|
+
"""
|
|
30
|
+
import abc
|
|
31
|
+
import torch
|
|
32
|
+
import torch.nn as nn
|
|
33
|
+
import torch.nn.functional as F
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ABCModel:
    """Interface listing the operations a model wrapper is expected to offer.

    The methods are tagged with ``abc.abstractmethod``, but the class does
    not derive from ``abc.ABC``, so instantiation is not blocked; an
    un-overridden method simply raises ``NotImplementedError`` when called.
    """

    @abc.abstractmethod
    def get_optimizer(self, opt_args):
        """Build the optimizer from ``opt_args``; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def epoch(self, loader, **kwargs):
        """Process one pass over ``loader``; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, seq):
        """Produce predictions for ``seq``; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def state_dict(self):
        """Return the model state; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def load(self, state_dict):
        """Restore the model from ``state_dict``; must be overridden."""
        raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class BaseRGBModel(ABCModel):
    """Shared optimizer and checkpoint helpers for model wrappers.

    Subclasses are expected to provide ``self._model`` (the wrapped network)
    and ``self.device``; both are read here but never set.
    """

    def get_optimizer(self, opt_args):
        """Return ``(optimizer, scaler)`` for the wrapped model.

        The optimizer is AdamW over ``self._get_params()`` configured with
        ``opt_args``. The scaler is a ``GradScaler`` when ``self.device``
        equals the string ``"cuda"`` and None otherwise.
        """
        optimizer = torch.optim.AdamW(self._get_params(), **opt_args)
        if self.device == "cuda":
            scaler = torch.cuda.amp.GradScaler()
        else:
            scaler = None
        return optimizer, scaler

    # Assume there is a self._model

    def _get_params(self):
        """Return the wrapped model's parameters as a list."""
        return list(self._model.parameters())

    def state_dict(self):
        """Return the state dict of the network, unwrapping DataParallel."""
        network = self._model
        if isinstance(network, nn.DataParallel):
            network = network.module
        return network.state_dict()

    def load(self, state_dict):
        """Load ``state_dict`` into the network, unwrapping DataParallel."""
        network = self._model
        if isinstance(network, nn.DataParallel):
            network = network.module
        network.load_state_dict(state_dict)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def step(optimizer, scaler, loss, lr_scheduler=None, backward_only=False):
    """
    Backpropagate ``loss`` and, unless told otherwise, apply one update.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer holding the parameters.
        scaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed
            precision training, or None to train without scaling.
        loss (torch.Tensor): The computed loss to backpropagate.
        lr_scheduler : The learning rate scheduler, stepped once per update.
            Default: None.
        backward_only (bool): If True, only perform the backward pass; the
            optimizer, scaler and scheduler are left untouched and the
            gradients are not cleared.
            Default: False.
    """
    use_scaler = scaler is not None

    # Backward pass, on the scaled loss when a scaler is in use.
    if use_scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    if backward_only:
        return

    # Parameter update.
    if use_scaler:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    if lr_scheduler is not None:
        lr_scheduler.step()

    optimizer.zero_grad()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class SingleStageGRU(nn.Module):
    """Bidirectional GRU followed by a per-step classification layer.

    Args:
        in_dim: Feature size of each input step.
        hidden_dim: Hidden size of the GRU (per direction).
        out_dim: Number of outputs per step.
        num_layers: Number of stacked GRU layers. Default: 5.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=5):
        super().__init__()
        self.backbone = nn.GRU(
            in_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence twice the hidden size.
        gru_out_dim = 2 * hidden_dim
        self.fc_out = nn.Sequential(
            nn.BatchNorm1d(gru_out_dim),
            nn.Dropout(),
            nn.Linear(gru_out_dim, out_dim),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, steps, in_dim) to (batch, steps, out_dim)."""
        batch_size, clip_len, _ = x.shape
        hidden, _ = self.backbone(x)
        # Fold batch and steps together so BatchNorm1d sees 2-D input.
        per_step = hidden.reshape(-1, hidden.shape[-1])
        scores = self.fc_out(per_step)
        return scores.view(batch_size, clip_len, -1)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class SingleStageTCN(nn.Module):
    """Stack of residual 1-D convolution layers with an optional mask.

    Args:
        in_dim: Feature size of each input step.
        hidden_dim: Channel count used inside the residual layers.
        out_dim: Number of outputs per step.
        num_layers: Number of residual layers.
        dilate: If truthy, layer ``i`` uses dilation ``2**i``; otherwise
            every layer uses dilation 1.
    """

    class DilatedResidualLayer(nn.Module):
        """Dilated conv -> ReLU -> 1x1 conv -> dropout, added to the input."""

        def __init__(self, dilation, in_channels, out_channels):
            super().__init__()
            # padding == dilation keeps the sequence length for kernel size 3.
            self.conv_dilated = nn.Conv1d(
                in_channels,
                out_channels,
                3,
                padding=dilation,
                dilation=dilation,
            )
            self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
            self.dropout = nn.Dropout()

        def forward(self, x, mask):
            """Return the masked residual update of ``x``."""
            residual = self.dropout(self.conv_1x1(F.relu(self.conv_dilated(x))))
            return (x + residual) * mask[:, 0:1, :]

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dilate):
        super().__init__()
        self.conv_1x1 = nn.Conv1d(in_dim, hidden_dim, 1)

        residual_layers = []
        for i in range(num_layers):
            dilation = 2**i if dilate else 1
            residual_layers.append(
                SingleStageTCN.DilatedResidualLayer(dilation, hidden_dim, hidden_dim)
            )
        self.layers = nn.ModuleList(residual_layers)

        self.conv_out = nn.Conv1d(hidden_dim, out_dim, 1)

    def forward(self, x, m=None):
        """Map ``x`` of shape (batch, steps, in_dim) to (batch, steps, out_dim).

        Args:
            x: Input features with the step axis on dim 1.
            m: Optional mask laid out like ``x`` (steps on dim 1); its first
                channel multiplies every layer output. When None, an
                all-ones mask is used.
        """
        batch_size, clip_len, _ = x.shape
        if m is None:
            mask = torch.ones((batch_size, 1, clip_len), device=x.device)
        else:
            mask = m.permute(0, 2, 1)

        # Conv1d works on (batch, channels, steps).
        features = self.conv_1x1(x.permute(0, 2, 1))
        for residual_layer in self.layers:
            features = residual_layer(features, mask)

        scores = self.conv_out(features) * mask[:, 0:1, :]
        return scores.permute(0, 2, 1)
|
|
File without changes
|