opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,266 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.nn as nn
33
+
34
+ from opensportslib.models.utils.modules import *
35
+
36
+
37
def build_head(cfg, default_args=None):
    """Construct the prediction head selected by ``cfg.type``.

    Args:
        cfg: Attribute-style config. ``cfg.type`` picks the head class; the
            other attributes read depend on that type.
        default_args (dict | None): Extra construction arguments. Only the
            ``"TrackingClassifier"`` head reads from it (key ``"input_dim"``),
            so it must be supplied for that type.

    Returns:
        nn.Module | None: The constructed head, or ``None`` when ``cfg.type``
        is not one of the recognised values.
    """
    head_type = cfg.type

    if head_type == "TrackingClassifier":
        return TrackingClassifierHead(
            input_dim=default_args["input_dim"],
            hidden_dim=cfg.hidden_dim,
            num_classes=cfg.num_classes,
            dropout=cfg.dropout,
        )

    if head_type == "LinearLayer":
        # Produces one output unit more than cfg.num_classes.
        return LinearLayerHead(
            input_dim=cfg.input_dim, output_dim=cfg.num_classes + 1
        )

    if head_type == "SpottingCALF":
        return SpottingCALFHead(
            num_classes=cfg.num_classes,
            dim_capsule=cfg.dim_capsule,
            num_detections=cfg.num_detections,
            chunk_size=cfg.chunk_size,
        )

    if head_type in ("", "gru", "deeper_gru", "mstcn", "asformer"):
        return TemporalE2EHead(head_type, cfg.feat_dim, cfg.num_classes)

    if head_type == "MV_LinearLayer":
        return MVHead(input_dim=cfg.feat_dim, output_dim=cfg.num_classes)

    # Unrecognised head types are reported as None instead of raising.
    return None
72
+
73
+
74
class MVHead(nn.Module):
    """Multi-view classification head: LayerNorm -> Linear -> Linear.

    There is no non-linearity between the two Linear layers.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        print("Inside MVHead")
        layers = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, input_dim),
            nn.Linear(input_dim, output_dim),
        ]
        # The Sequential indices (fc.0 / fc.1 / fc.2) define the state_dict keys.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)
87
+
88
class TrackingClassifierHead(nn.Module):
    """Two-layer MLP classifier for tracking features.

    Layout: Dropout(p) -> Linear(input, hidden) -> ReLU -> Dropout(p * 0.5)
    -> Linear(hidden, num_classes).
    """

    def __init__(self, input_dim, hidden_dim, num_classes, dropout=0.1):
        super().__init__()

        hidden_layer = nn.Linear(input_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, num_classes)
        # Keep the module order: it fixes the state_dict keys
        # (classifier.1.* and classifier.4.*).
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            hidden_layer,
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            output_layer,
        )

    def forward(self, x):
        return self.classifier(x)
104
+
105
class TemporalE2EHead(nn.Module):
    """Temporal prediction head used by the end-to-end model.

    Args:
        temporal_arch (str): ``""`` (fully-connected), ``"gru"``,
            ``"deeper_gru"`` (3-layer GRU), ``"mstcn"`` or ``"asformer"``.
        feat_dim (int): Input feature dimension.
        num_classes (int): Number of classes passed to the predictor.

    Raises:
        NotImplementedError: If ``temporal_arch`` is not one of the values
            above. Previously, unknown non-GRU values were accepted silently
            and left ``_pred_fine`` unset, so ``forward`` failed later with an
            AttributeError.
    """

    def __init__(self, temporal_arch, feat_dim, num_classes):
        super().__init__()
        # Prevent the GRU params from going too big (cap it at a RegNet-Y 800MF)
        MAX_GRU_HIDDEN_DIM = 768
        if "gru" in temporal_arch:
            hidden_dim = feat_dim
            if hidden_dim > MAX_GRU_HIDDEN_DIM:
                hidden_dim = MAX_GRU_HIDDEN_DIM
                print("Clamped GRU hidden dim: {} -> {}".format(feat_dim, hidden_dim))
            if temporal_arch not in ("gru", "deeper_gru"):
                raise NotImplementedError(temporal_arch)
            self._pred_fine = GRUPrediction(
                feat_dim,
                num_classes,
                hidden_dim,
                num_layers=3 if temporal_arch == "deeper_gru" else 1,
            )
        elif temporal_arch == "mstcn":
            self._pred_fine = TCNPrediction(feat_dim, num_classes, 3)
        elif temporal_arch == "asformer":
            self._pred_fine = ASFormerPrediction(feat_dim, num_classes, 3)
        elif temporal_arch == "":
            self._pred_fine = FCPrediction(feat_dim, num_classes)
        else:
            # Fail at construction rather than with an AttributeError in forward().
            raise NotImplementedError(temporal_arch)

    def forward(self, inputs):
        return self._pred_fine(inputs)
133
+
134
+
135
class LinearLayerHead(torch.nn.Module):
    """Dropout(p=0.4) -> Linear -> Sigmoid; each output is squashed into [0, 1]."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.drop = torch.nn.Dropout(p=0.4)
        self.head = torch.nn.Linear(input_dim, output_dim)
        self.sigm = torch.nn.Sigmoid()

    def forward(self, inputs):
        dropped = self.drop(inputs)
        logits = self.head(dropped)
        return self.sigm(logits)
147
+
148
+
149
class SpottingCALFHead(torch.nn.Module):
    """CALF spotting module built on top of segmentation capsules.

    ``forward`` returns a tensor of shape
    ``(batch, num_detections, 2 + num_classes)``. The first two values per
    detection are sigmoid outputs; the remaining ``num_classes`` values are a
    softmax over classes.

    Args:
        num_classes (int): Number of event classes.
        dim_capsule (int): Capsule size per class; ``conv_seg`` is expected to
            carry ``num_classes * dim_capsule`` channels (so that, after the
            extra per-class score channel is appended, it matches
            ``conv_spot_1``'s ``num_classes * (dim_capsule + 1)`` inputs).
        num_detections (int): Number of detection slots predicted.
        chunk_size (int): Used to size the flattened feature fed to the
            confidence/class branches: ``16 * (chunk_size // 8 - 1)``.
    """

    def __init__(self, num_classes, dim_capsule, num_detections, chunk_size):
        super().__init__()

        self.num_classes = num_classes
        self.dim_capsule = dim_capsule
        self.num_detections = num_detections
        self.chunk_size = chunk_size

        # -------------------
        # detection module
        # -------------------
        self.max_pool_spot = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1))
        self.kernel_spot_size = 3
        k = self.kernel_spot_size
        pad_top = (k - 1) // 2
        pad_bottom = k - 1 - pad_top
        # ZeroPad2d takes (left, right, top, bottom): only the height axis is
        # padded, so each (k, 1) stride-1 conv keeps that axis' length.
        self.pad_spot_1 = nn.ZeroPad2d((0, 0, pad_top, pad_bottom))
        self.conv_spot_1 = nn.Conv2d(
            in_channels=num_classes * (dim_capsule + 1),
            out_channels=32,
            kernel_size=(k, 1),
        )
        self.max_pool_spot_1 = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1))
        self.pad_spot_2 = nn.ZeroPad2d((0, 0, pad_top, pad_bottom))
        self.conv_spot_2 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=(k, 1))
        self.max_pool_spot_2 = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1))

        flat_channels = 16 * (chunk_size // 8 - 1)

        # Confidence branch: 2 values per detection.
        self.conv_conf = nn.Conv2d(
            in_channels=flat_channels,
            out_channels=self.num_detections * 2,
            kernel_size=(1, 1),
        )

        # Class branch: num_classes values per detection.
        self.conv_class = nn.Conv2d(
            in_channels=flat_channels,
            out_channels=self.num_detections * self.num_classes,
            kernel_size=(1, 1),
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, conv_seg, output_segmentation):
        """Run the spotting module.

        Args:
            conv_seg (Tensor): 4-D capsule features, concatenated on dim 1.
            output_segmentation (Tensor): 3-D segmentation scores; reshaped
                ``(d0, d1, d2) -> (d0, d2, d1, 1)`` before concatenation.

        Returns:
            Tensor: ``(batch, num_detections, 2 + num_classes)``.
        """
        # Append (1 - segmentation score) as extra channels to the capsules.
        reverse_seg = (1 - output_segmentation).unsqueeze(2).permute(0, 3, 1, 2)
        features = torch.cat((conv_seg, reverse_seg), dim=1)

        features = self.max_pool_spot(F.relu(features))
        features = F.relu(self.conv_spot_1(self.pad_spot_1(features)))
        features = self.max_pool_spot_1(features)
        features = F.relu(self.conv_spot_2(self.pad_spot_2(features)))
        features = self.max_pool_spot_2(features)

        # Flatten everything but the batch axis into the channel axis so the
        # 1x1 convolutions below act as fully-connected layers.
        batch = features.size()[0]
        flat = features.view(batch, -1, 1, 1)

        conf_pred = torch.sigmoid(
            self.conv_conf(flat).view(flat.shape[0], self.num_detections, 2)
        )
        class_pred = self.softmax(
            self.conv_class(flat).view(
                flat.shape[0], self.num_detections, self.num_classes
            )
        )

        return torch.cat((conf_pred, class_pred), dim=-1)
@@ -0,0 +1,210 @@
1
+ import torch
2
+ from torch import nn
3
+ from opensportslib.core.utils.data import batch_tensor, unbatch_tensor
4
+
5
def build_neck(cfg, default_args=None):
    """Construct the neck selected by ``cfg.type``.

    Args:
        cfg: Attribute-style config with ``type`` and ``agr_type``;
            ``"TemporalAggregation"`` also reads ``hidden_dim`` and ``dropout``
            plus the optional ``use_position_encoding`` (default False),
            ``num_attention_heads`` (default 4) and ``lstm_dropout``
            (default 0.1).
        default_args (dict): Runtime values. ``"MV_Aggregate"`` reads
            ``"model"``, ``"feat_dim"`` and optionally ``"lifting_net"``
            (an empty ``nn.Sequential`` when absent); ``"TemporalAggregation"``
            reads ``"window_size"``.

    Returns:
        nn.Module: The constructed neck.

    Raises:
        ValueError: If ``cfg.type`` is not a known neck type.
    """
    neck_type = cfg.type

    if neck_type == "MV_Aggregate":
        return MVAggregate(
            agr_type=cfg.agr_type,
            model=default_args["model"],
            feat_dim=default_args["feat_dim"],
            lifting_net=(
                default_args["lifting_net"]
                if "lifting_net" in default_args
                else nn.Sequential()
            ),
        )

    if neck_type == "TemporalAggregation":
        return TemporalAggregation(
            temporal_type=cfg.agr_type,
            hidden_dim=cfg.hidden_dim,
            window_size=default_args["window_size"],
            dropout=cfg.dropout,
            use_position_encoding=getattr(cfg, "use_position_encoding", False),
            num_attention_heads=getattr(cfg, "num_attention_heads", 4),
            lstm_dropout=getattr(cfg, "lstm_dropout", 0.1),
        )

    raise ValueError(f"Unknown neck type: {neck_type}")
27
+
28
+
29
class MVAggregate(nn.Module):
    """Encode every view with ``model`` and pool the views into one feature.

    Args:
        agr_type (str): ``"max"`` -> ViewMaxAggregate, ``"mean"`` ->
            ViewAvgAggregate; any other value -> WeightedAggregate (learned
            view attention).
        model (nn.Module): Per-view encoder shared across views.
        feat_dim (int): Feature size of the pooled view vector.
        lifting_net (nn.Module | None): Applied to the per-view features before
            pooling. ``None`` (the default) means a fresh empty
            ``nn.Sequential`` (identity). The previous default was a single
            ``nn.Sequential()`` instance shared by every construction.
    """

    def __init__(self, agr_type, model, feat_dim, lifting_net=None):
        super().__init__()
        if lifting_net is None:
            lifting_net = nn.Sequential()
        self.agr_type = agr_type
        self.model = model
        self.feat_dim = feat_dim
        self.lifting_net = lifting_net
        print("Inside NECK BUILDER - AGR TYPE:", self.agr_type)

        if self.agr_type == "max":
            self.aggregation_model = ViewMaxAggregate(model=model, lifting_net=lifting_net)
        elif self.agr_type == "mean":
            self.aggregation_model = ViewAvgAggregate(model=model, lifting_net=lifting_net)
        else:
            # Fallback for every other value: learned attention over views.
            self.aggregation_model = WeightedAggregate(
                model=model, feat_dim=feat_dim, lifting_net=lifting_net
            )

        self.inter = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, feat_dim),
            nn.Linear(feat_dim, feat_dim),
        )

    def forward(self, mvimages):
        """Return ``(inter, attention)``.

        ``inter`` is the pooled view feature passed through ``self.inter``.
        ``attention`` is the aggregator's second output: the per-view weights
        for WeightedAggregate, or the per-view features for the max/mean
        aggregators.
        """
        pooled_view, attention = self.aggregation_model(mvimages)
        inter = self.inter(pooled_view)
        return inter, attention
56
+
57
+
58
+
59
class WeightedAggregate(nn.Module):
    """Pool per-view features using learned, input-dependent view attention.

    Args:
        model (nn.Module): Per-view encoder applied to every view.
        feat_dim (int): Feature size E of the (lifted) per-view features.
        lifting_net (nn.Module | None): Applied to per-view features before
            attention. ``None`` (default) means a fresh empty ``nn.Sequential``
            (identity) for each instance, instead of one shared default module.
    """

    def __init__(self, model, feat_dim, lifting_net=None):
        super().__init__()
        self.model = model
        self.lifting_net = nn.Sequential() if lifting_net is None else lifting_net
        self.feature_dim = feat_dim

        r1 = -1
        r2 = 1
        # (r1 - r2) * U[0, 1) + r2 -> initial values in (r1, r2] = (-1, 1].
        self.attention_weights = nn.Parameter((r1 - r2) * torch.rand(feat_dim, feat_dim) + r2)

        # Not used by forward(). It is a registered submodule, so it appears in
        # the state_dict; removing it would change the checkpoint keys.
        self.normReLu = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.ReLU()
        )

        self.relu = nn.ReLU()

    def forward(self, mvimages):
        """Aggregate a batch of multi-view clips.

        Args:
            mvimages (Tensor): 6-D tensor unpacked as
                (batch, views, channels, depth, height, width).

        Returns:
            tuple: ``(pooled, view_weights)``. ``view_weights`` has shape
            ``(B, V)``. ``pooled`` is the attention-weighted sum of the per-view
            features over the view axis, with singleton dims squeezed (so
            ``B == 1`` yields a 1-D tensor).
        """
        B, V, C, D, H, W = mvimages.shape  # Batch, Views, Channel, Depth, Height, Width
        aux = self.lifting_net(unbatch_tensor(self.model(batch_tensor(mvimages, dim=1, squeeze=True)), B, dim=1, unsqueeze=True))

        ##################### VIEW ATTENTION #####################
        # aux is treated as (B, V, E): bmm below needs 3-D and the reshape to
        # (B, V * V) needs the product to be (B, V, V).

        # Learned projection of each view feature: (B, V, E) @ (E, E).
        aux = torch.matmul(aux, self.attention_weights)

        # Pairwise view similarities, negative values clipped: (B, V, V).
        relu_res = self.relu(torch.bmm(aux, aux.permute(0, 2, 1)))

        # Normalise so the V*V scores of each sample sum to 1, then sum over
        # the first view axis to get one weight per view: (B, V).
        flat_scores = torch.reshape(relu_res, (B, V * V))
        flat_scores = flat_scores / torch.sum(flat_scores, dim=1, keepdim=True)
        final_attention_weights = torch.sum(torch.reshape(flat_scores, (B, V, V)), 1)

        # Weighted sum over views. aux is deliberately NOT squeezed here: the
        # former aux.squeeze() dropped the view axis when V == 1 (or the
        # feature axis when E == 1), making the broadcast mix samples across
        # the batch.
        output = torch.mul(aux, final_attention_weights.unsqueeze(-1))
        output = torch.sum(output, 1)

        return output.squeeze(), final_attention_weights
113
+
114
+
115
class ViewMaxAggregate(nn.Module):
    """Pool per-view features with an element-wise max over the view axis.

    Args:
        model (nn.Module): Per-view encoder applied to every view.
        lifting_net (nn.Module | None): Applied to per-view features before
            pooling. ``None`` (default) means a fresh empty ``nn.Sequential``
            (identity) for each instance, instead of one shared default module.
    """

    def __init__(self, model, lifting_net=None):
        super().__init__()
        self.model = model
        self.lifting_net = nn.Sequential() if lifting_net is None else lifting_net

    def forward(self, mvimages):
        """Return ``(pooled, per_view_features)``.

        ``mvimages`` is unpacked as (batch, views, channels, depth, height,
        width); the unpack also rejects inputs that are not 6-D. ``pooled`` is
        the max over dim 1 with singleton dims squeezed.
        """
        B, V, C, D, H, W = mvimages.shape  # Batch, Views, Channel, Depth, Height, Width
        aux = self.lifting_net(unbatch_tensor(self.model(batch_tensor(mvimages, dim=1, squeeze=True)), B, dim=1, unsqueeze=True))
        pooled_view = torch.max(aux, dim=1)[0]
        return pooled_view.squeeze(), aux
126
+
127
+
128
class ViewAvgAggregate(nn.Module):
    """Pool per-view features with a mean over the view axis.

    Args:
        model (nn.Module): Per-view encoder applied to every view.
        lifting_net (nn.Module | None): Applied to per-view features before
            pooling. ``None`` (default) means a fresh empty ``nn.Sequential``
            (identity) for each instance, instead of one shared default module.
    """

    def __init__(self, model, lifting_net=None):
        super().__init__()
        self.model = model
        self.lifting_net = nn.Sequential() if lifting_net is None else lifting_net

    def forward(self, mvimages):
        """Return ``(pooled, per_view_features)``.

        ``mvimages`` is unpacked as (batch, views, channels, depth, height,
        width); the unpack also rejects inputs that are not 6-D. ``pooled`` is
        the mean over dim 1 with singleton dims squeezed.
        """
        B, V, C, D, H, W = mvimages.shape  # Batch, Views, Channel, Depth, Height, Width
        aux = self.lifting_net(unbatch_tensor(self.model(batch_tensor(mvimages, dim=1, squeeze=True)), B, dim=1, unsqueeze=True))
        pooled_view = torch.mean(aux, dim=1)
        return pooled_view.squeeze(), aux
139
+
140
class TemporalAggregation(nn.Module):
    """Collapse a feature sequence over its time axis into one vector per sample.

    Args:
        temporal_type (str): ``"avgpool"``, ``"maxpool"``, ``"tcn"``,
            ``"attention"`` or ``"bilstm"``. The last three apply a temporal
            module and then max-pool over time.
        hidden_dim (int): Feature size of each time step.
        window_size (int): Length of the learnable position encoding; only
            used when ``use_position_encoding`` is True.
        dropout (float): Attention dropout (``"attention"`` only).
        use_position_encoding (bool): Add a learnable per-step encoding.
        num_attention_heads (int): Heads for ``"attention"``.
        lstm_dropout (float): Inter-layer dropout of the 2-layer ``"bilstm"``.

    Attributes:
        feat_dim (int): Output feature size (``2 * hidden_dim`` for bilstm).

    Raises:
        ValueError: If ``temporal_type`` is not supported. Unknown values used
            to be accepted, and ``forward`` then returned the un-pooled 3-D
            sequence.
    """

    _SUPPORTED_TYPES = ("avgpool", "maxpool", "tcn", "attention", "bilstm")

    def __init__(self, temporal_type, hidden_dim, window_size, dropout=0.1,
                 use_position_encoding=False, num_attention_heads=4, lstm_dropout=0.3):
        super().__init__()

        if temporal_type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"Unknown temporal aggregation type: {temporal_type!r} "
                f"(expected one of {self._SUPPORTED_TYPES})"
            )

        self.num_attention_heads = num_attention_heads
        self.temporal_type = temporal_type
        self.hidden_dim = hidden_dim
        self.feat_dim = hidden_dim * 2 if temporal_type == "bilstm" else hidden_dim
        self.use_position_encoding = use_position_encoding
        self.lstm_dropout = lstm_dropout

        # learnable temporal position encoding (only used when explicitly enabled)
        if self.use_position_encoding:
            self.temporal_position_encoding = nn.Parameter(
                torch.randn(1, window_size, hidden_dim) * 0.02
            )

        # build temporal module (None for the parameter-free pooling types)
        self.temporal = self._build_temporal_module(temporal_type, hidden_dim, dropout)

    def _build_temporal_module(self, temporal_type, hidden_dim, dropout):
        """Return the learnable temporal module for ``temporal_type`` (or None)."""
        if temporal_type == 'bilstm':
            return nn.LSTM(
                hidden_dim, hidden_dim, num_layers=2,
                batch_first=True, bidirectional=True, dropout=self.lstm_dropout
            )

        elif temporal_type == 'tcn':
            return nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
            )

        elif temporal_type == 'attention':
            return nn.MultiheadAttention(
                hidden_dim, num_heads=self.num_attention_heads,
                dropout=dropout, batch_first=True
            )

        else:  # avgpool, maxpool
            return None

    def forward(self, x):
        """Pool ``x`` over dim 1.

        ``x`` is added to a ``(1, window_size, hidden_dim)`` encoding when
        enabled, so it is expected as (batch, time, hidden_dim) with
        time <= window_size. Returns ``(batch, feat_dim)``.
        """
        seq_len = x.size(1)

        if self.use_position_encoding:
            x = x + self.temporal_position_encoding[:, :seq_len, :]

        if self.temporal_type == 'avgpool':
            return torch.mean(x, dim=1)

        if self.temporal_type == 'maxpool':
            return torch.max(x, dim=1)[0]

        if self.temporal_type == 'tcn':
            # Conv1d expects channels first: swap to (B, D, T) and back.
            x = self.temporal(x.permute(0, 2, 1)).permute(0, 2, 1)
        elif self.temporal_type == 'attention':
            x, _ = self.temporal(x, x, x)
        elif self.temporal_type == 'bilstm':
            x, _ = self.temporal(x)
        else:
            raise ValueError(f"Unknown temporal aggregation type: {self.temporal_type!r}")

        # Max over the time axis.
        return torch.max(x, dim=1)[0]
@@ -0,0 +1,176 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+ import abc
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+
36
class ABCModel:
    """Interface for trainable model wrappers.

    NOTE(review): this class does not derive from ``abc.ABC`` (or use
    ``abc.ABCMeta``), so the ``@abc.abstractmethod`` markers below are not
    enforced. ``ABCModel`` and incomplete subclasses can still be
    instantiated, and an unimplemented method raises ``NotImplementedError``
    when called.
    """

    @abc.abstractmethod
    def get_optimizer(self, opt_args):
        """Subclass hook taking ``opt_args``; the base version raises."""
        raise NotImplementedError

    @abc.abstractmethod
    def epoch(self, loader, **kwargs):
        """Subclass hook taking a ``loader``; the base version raises."""
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, seq):
        """Subclass hook taking ``seq``; the base version raises."""
        raise NotImplementedError

    @abc.abstractmethod
    def state_dict(self):
        """Subclass hook; the base version raises."""
        raise NotImplementedError

    @abc.abstractmethod
    def load(self, state_dict):
        """Subclass hook taking a ``state_dict``; the base version raises."""
        raise NotImplementedError
57
+
58
+
59
class BaseRGBModel(ABCModel):
    """Base implementation for models that keep their network in ``self._model``.

    Subclasses are expected to set ``self._model`` (an ``nn.Module``, possibly
    wrapped in ``nn.DataParallel`` or ``DistributedDataParallel``) and
    ``self.device``.
    """

    # Wrappers that hold the real network in ``.module``. Saving or loading
    # through them directly would add or expect a "module." key prefix.
    _WRAPPER_TYPES = (nn.DataParallel, nn.parallel.DistributedDataParallel)

    def get_optimizer(self, opt_args):
        """Return ``(AdamW optimizer, GradScaler or None)``.

        ``opt_args`` is forwarded as keyword arguments to ``torch.optim.AdamW``.
        A ``GradScaler`` is created only when ``self.device == "cuda"``; this
        is an exact string comparison, so e.g. ``"cuda:0"`` gets ``None``.
        """
        return torch.optim.AdamW(self._get_params(), **opt_args), (
            torch.cuda.amp.GradScaler() if self.device == "cuda" else None
        )

    def _get_params(self):
        """Parameters handed to the optimizer (all of ``self._model``)."""
        return list(self._model.parameters())

    def _unwrapped_model(self):
        """``self._model`` with any DataParallel/DDP wrapper removed."""
        if isinstance(self._model, self._WRAPPER_TYPES):
            return self._model.module
        return self._model

    def state_dict(self):
        """State dict of the underlying network, without a wrapper prefix."""
        return self._unwrapped_model().state_dict()

    def load(self, state_dict):
        """Load ``state_dict`` into the underlying network."""
        self._unwrapped_model().load_state_dict(state_dict)
81
+
82
+
83
def step(optimizer, scaler, loss, lr_scheduler=None, backward_only=False):
    """Backpropagate ``loss`` and, unless ``backward_only``, apply one update.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer stepped (and whose
            gradients are zeroed) after the backward pass.
        scaler (torch.cuda.amp.GradScaler | None): When given, the loss is
            scaled before ``backward`` and the update goes through
            ``scaler.step`` / ``scaler.update``.
        loss (torch.Tensor): Loss to backpropagate.
        lr_scheduler: Optional scheduler, stepped once per update.
            Default: None.
        backward_only (bool): If True, stop after the backward pass: nothing
            is stepped and gradients are left in place. Default: False.
    """
    scaled = loss if scaler is None else scaler.scale(loss)
    scaled.backward()

    if backward_only:
        return

    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    if lr_scheduler is not None:
        lr_scheduler.step()

    optimizer.zero_grad()
110
+
111
+
112
class SingleStageGRU(nn.Module):
    """Bidirectional GRU followed by a per-step BatchNorm/Dropout/Linear head.

    Input is unpacked as ``(batch, clip_len, in_dim)`` (``batch_first=True``);
    the output is ``(batch, clip_len, out_dim)``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=5):
        super().__init__()
        self.backbone = nn.GRU(
            in_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward hidden states are concatenated.
        gru_out_dim = 2 * hidden_dim
        self.fc_out = nn.Sequential(
            nn.BatchNorm1d(gru_out_dim),
            nn.Dropout(),
            nn.Linear(gru_out_dim, out_dim),
        )

    def forward(self, x):
        n_batch, n_steps, _ = x.shape
        sequence, _ = self.backbone(x)
        # BatchNorm1d / Linear run on every time step as an independent row.
        per_step = sequence.reshape(-1, sequence.shape[-1])
        return self.fc_out(per_step).view(n_batch, n_steps, -1)
134
+
135
+
136
class SingleStageTCN(nn.Module):
    """Single-stage temporal conv net: 1x1 in-projection, residual stack, 1x1 out.

    Input is unpacked as ``(batch, clip_len, in_dim)``; the output is
    ``(batch, clip_len, out_dim)``. The optional mask ``m`` has time on dim 1
    like ``x`` and is permuted to channels-first; only its first channel is
    multiplied into the activations.
    """

    class DilatedResidualLayer(nn.Module):
        """Residual block: dilated conv(k=3) -> ReLU -> 1x1 conv -> Dropout, then masked."""

        def __init__(self, dilation, in_channels, out_channels):
            super().__init__()
            # padding == dilation keeps the temporal length for kernel size 3.
            self.conv_dilated = nn.Conv1d(
                in_channels, out_channels, 3, padding=dilation, dilation=dilation
            )
            self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
            self.dropout = nn.Dropout()

        def forward(self, x, mask):
            residual = self.dropout(self.conv_1x1(F.relu(self.conv_dilated(x))))
            keep = mask[:, 0:1, :]
            return (x + residual) * keep

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dilate):
        super().__init__()
        self.conv_1x1 = nn.Conv1d(in_dim, hidden_dim, 1)
        # Exponentially growing dilation (1, 2, 4, ...) when dilate, else 1 everywhere.
        dilations = [2**i if dilate else 1 for i in range(num_layers)]
        blocks = [
            SingleStageTCN.DilatedResidualLayer(d, hidden_dim, hidden_dim)
            for d in dilations
        ]
        self.layers = nn.ModuleList(blocks)
        self.conv_out = nn.Conv1d(hidden_dim, out_dim, 1)

    def forward(self, x, m=None):
        batch_size, clip_len, _ = x.shape
        mask = (
            torch.ones((batch_size, 1, clip_len), device=x.device)
            if m is None
            else m.permute(0, 2, 1)
        )
        # Conv1d works channels-first: (B, T, C) -> (B, C, T).
        hidden = self.conv_1x1(x.permute(0, 2, 1))
        for block in self.layers:
            hidden = block(hidden, mask)
        hidden = self.conv_out(hidden) * mask[:, 0:1, :]
        return hidden.permute(0, 2, 1)
File without changes