opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,390 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+ # Adapted from: https://github.com/ChinaYi/ASFormer
31
+ # Source license:
32
+ """
33
+ MIT License
34
+
35
+ Copyright (c) 2021 ChinaYi
36
+
37
+ Permission is hereby granted, free of charge, to any person obtaining a copy
38
+ of this software and associated documentation files (the "Software"), to deal
39
+ in the Software without restriction, including without limitation the rights
40
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41
+ copies of the Software, and to permit persons to whom the Software is
42
+ furnished to do so, subject to the following conditions:
43
+
44
+ The above copyright notice and this permission notice shall be included in all
45
+ copies or substantial portions of the Software.
46
+
47
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53
+ SOFTWARE.
54
+ """
55
+
56
+ import torch
57
+ import torch.nn as nn
58
+ import torch.nn.functional as F
59
+
60
+ import copy
61
+ import numpy as np
62
+ import math
63
+
64
+
65
# Default device chosen once at import time: CUDA if available, else CPU.
# NOTE(review): this value is fixed at import and does not follow
# ``Module.to()``; tensors created inside forward passes should preferably take
# the device (and dtype) of their input tensors instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+
67
+
68
def exponential_descrease(idx_decoder, p=3):
    """Return the exponential decay factor ``exp(-p * idx_decoder)``.

    ``MyTransformer`` uses it as the ``alpha`` weight of its ``idx_decoder``-th
    decoder. The misspelt name is kept for backward compatibility.
    """
    exponent = -p * idx_decoder
    return math.exp(exponent)
70
+
71
+
72
class AttentionHelper(nn.Module):
    """Scaled dot-product attention with a multiplicative padding mask."""

    def __init__(self):
        super(AttentionHelper, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
        '''
        Scaled dot-product attention.

        :param proj_query: (B, C, L1) queries.
        :param proj_key: (B, C, L2) keys; C must equal the query channels.
        :param proj_val: (B, Cv, L2) values.
        :param padding_mask: mask broadcastable to (B, L1, L2); 1 keeps a key,
            0 masks it out.
        :return: tuple ``(out, attention)`` where ``out`` is (B, Cv, L1) and
            ``attention`` is the (B, L2, L1) weight matrix.
        '''
        m, c1, l1 = proj_query.shape
        m, c2, l2 = proj_key.shape

        assert c1 == c2

        energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)  # (B, L1, L2)
        attention = energy / np.sqrt(c1)
        # log(1e-6) adds a large negative bias to masked keys before softmax.
        attention = attention + torch.log(padding_mask + 1e-6)
        attention = self.softmax(attention)
        # Remove the small residual weight that masked keys still receive.
        attention = attention * padding_mask
        attention = attention.permute(0, 2, 1)  # (B, L2, L1)
        out = torch.bmm(proj_val, attention)
        return out, attention


class AttLayer(nn.Module):
    '''
    Single-head ASFormer attention over (B, C, L) temporal feature maps.

    Queries and keys are projected from ``x1``; values are projected from
    ``x2`` in the ``'decoder'`` stage and from ``x1`` in the ``'encoder'``
    stage.

    :param q_dim, k_dim, v_dim: input channels for query / key / value.
    :param r1, r2, r3: channel reduction ratios of the query / key / value
        projections; ``q_dim // r1`` must equal ``k_dim // r2``.
    :param bl: block length (``'block_att'``) / window length
        (``'sliding_att'``).
    :param stage: ``'encoder'`` or ``'decoder'``.
    :param att_type: ``'normal_att'``, ``'block_att'`` or ``'sliding_att'``.
    '''

    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
        super(AttLayer, self).__init__()

        self.query_conv = nn.Conv1d(in_channels=q_dim, out_channels=q_dim // r1, kernel_size=1)
        self.key_conv = nn.Conv1d(in_channels=k_dim, out_channels=k_dim // r2, kernel_size=1)
        self.value_conv = nn.Conv1d(in_channels=v_dim, out_channels=v_dim // r3, kernel_size=1)

        # Maps the reduced value channels back to v_dim.
        self.conv_out = nn.Conv1d(in_channels=v_dim // r3, out_channels=v_dim, kernel_size=1)

        self.bl = bl
        self.stage = stage
        self.att_type = att_type
        assert self.att_type in ['normal_att', 'block_att', 'sliding_att']
        assert self.stage in ['encoder', 'decoder']

        self.att_helper = AttentionHelper()
        # Non-persistent buffer: it follows Module.to()/.cuda()/.half(), but it
        # is neither saved to nor expected in the state_dict, so existing
        # checkpoints still load with strict=True.
        self.register_buffer('window_mask', self.construct_window_mask(), persistent=False)

    def construct_window_mask(self):
        '''
        Build the (1, bl, bl + 2 * (bl // 2)) sliding-window mask used by
        ``'sliding_att'``: row ``i`` is 1 on columns ``i .. i + bl - 1``.
        '''
        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
        for i in range(self.bl):
            window_mask[:, i, i:i + self.bl] = 1
        return window_mask

    def forward(self, x1, x2, mask):
        '''
        :param x1: (B, q_dim, L) features for queries/keys (and values in the
            encoder stage).
        :param x2: (B, v_dim, L) features for values in the decoder stage.
        :param mask: (B, *, L) mask; only channel 0 is used.
        :return: (B, v_dim, L) output, zeroed where the mask is 0.
        '''
        query = self.query_conv(x1)
        key = self.key_conv(x1)

        if self.stage == 'decoder':
            assert x2 is not None
            value = self.value_conv(x2)
        else:
            value = self.value_conv(x1)

        if self.att_type == 'normal_att':
            return self._normal_self_att(query, key, value, mask)
        elif self.att_type == 'block_att':
            return self._block_wise_self_att(query, key, value, mask)
        elif self.att_type == 'sliding_att':
            return self._sliding_window_self_att(query, key, value, mask)

    def _num_blocks(self, length):
        """Return ``(nb, pad)``: blocks of ``self.bl`` covering ``length`` and the zero tail."""
        nb = -(-length // self.bl)  # ceil division
        return nb, nb * self.bl - length

    def _normal_self_att(self, q, k, v, mask):
        """Global attention over the whole sequence."""
        L = q.size(-1)
        # (B, 1, L) key-padding mask in the activations' dtype.
        padding_mask = mask[:, 0:1, :].to(q.dtype)
        output, _ = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
        output = self.conv_out(F.relu(output))
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]

    def _block_wise_self_att(self, q, k, v, mask):
        """Attention restricted to non-overlapping blocks of ``self.bl`` frames."""
        m_batchsize, _, L = q.size()
        nb, pad = self._num_blocks(L)

        def to_blocks(x):
            # (B, C, L) -> zero-pad to nb * bl -> (B * nb, C, bl), batch-major.
            channels = x.size(1)
            x = F.pad(x, (0, pad))
            return x.reshape(m_batchsize, channels, nb, self.bl).permute(0, 2, 1, 3).reshape(
                m_batchsize * nb, channels, self.bl)

        padding_mask = to_blocks(mask[:, 0:1, :].to(q.dtype))
        q, k, v = to_blocks(q), to_blocks(k), to_blocks(v)

        output, _ = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
        output = self.conv_out(F.relu(output))

        # conv_out restores v_dim channels (not v_dim // r3), so infer them.
        output = output.reshape(m_batchsize, nb, -1, self.bl).permute(0, 2, 1, 3).reshape(
            m_batchsize, -1, nb * self.bl)
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]

    def _unfold_windows(self, x, win):
        """(B, C, nb*bl + 2*(bl//2)) -> (nb * B, C, win) windows taken with stride bl, block-major."""
        channels = x.size(1)
        return x.unfold(2, win, self.bl).permute(2, 0, 1, 3).reshape(-1, channels, win)

    def _sliding_window_self_att(self, q, k, v, mask):
        """Each block of ``bl`` queries attends to a window of ``bl + 2*(bl//2)`` keys."""
        m_batchsize, c1, L = q.size()
        nb, pad = self._num_blocks(L)
        half = self.bl // 2
        win = self.bl + 2 * half

        # Queries: (B, c1, nb*bl) -> (nb*B, c1, bl). The order is block-major so
        # that it matches the unfolded keys/values below; batch-major ordering
        # here mixes samples when B > 1.
        q = F.pad(q, (0, pad))
        q = q.reshape(m_batchsize, c1, nb, self.bl).permute(2, 0, 1, 3).reshape(
            nb * m_batchsize, c1, self.bl)

        # Keys/values/mask: zero-pad the tail to whole blocks plus a bl//2 halo
        # on both ends, then cut one window per block.
        padding_mask = F.pad(mask[:, 0:1, :].to(q.dtype), (half, pad + half))
        k = self._unfold_windows(F.pad(k, (half, pad + half)), win)
        v = self._unfold_windows(F.pad(v, (half, pad + half)), win)
        padding_mask = self._unfold_windows(padding_mask, win)  # (nb*B, 1, win)

        window_mask = self.window_mask.to(device=q.device, dtype=q.dtype)
        final_mask = window_mask * padding_mask  # broadcast -> (nb*B, bl, win)

        output, _ = self.att_helper.scalar_dot_att(q, k, v, final_mask)
        output = self.conv_out(F.relu(output))

        # Undo the block-major ordering: (nb*B, C, bl) -> (B, C, nb*bl).
        output = output.reshape(nb, m_batchsize, -1, self.bl).permute(1, 2, 0, 3).reshape(
            m_batchsize, -1, nb * self.bl)
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]
243
+
244
+
245
class MultiHeadAttLayer(nn.Module):
    """Run ``num_head`` independent AttLayer heads and fuse them with a 1x1 conv."""

    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
        super().__init__()
        # Each head outputs v_dim channels, so the concatenation has v_dim * num_head.
        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
        self.layers = nn.ModuleList([
            AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
            for _ in range(num_head)
        ])
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x1, x2, mask):
        head_outputs = [head(x1, x2, mask) for head in self.layers]
        fused = torch.cat(head_outputs, dim=1)
        return self.conv_out(self.dropout(fused))
258
+
259
+
260
class ConvFeedForward(nn.Module):
    """Dilated 1-D convolution (kernel 3) followed by ReLU; preserves sequence length."""

    def __init__(self, dilation, in_channels, out_channels):
        super().__init__()
        # padding == dilation keeps the length unchanged for a kernel of size 3.
        dilated_conv = nn.Conv1d(in_channels, out_channels, 3, padding=dilation, dilation=dilation)
        self.layer = nn.Sequential(dilated_conv, nn.ReLU())

    def forward(self, x):
        return self.layer(x)
270
+
271
+
272
class FCFeedForward(nn.Module):
    """Per-frame two-layer MLP (1x1 conv -> ReLU -> Dropout -> 1x1 conv)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv1d(in_channels, out_channels, 1),  # a 1x1 conv is a per-frame linear layer
            nn.ReLU(),
            nn.Dropout(),
            nn.Conv1d(out_channels, out_channels, 1),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)
284
+
285
+
286
class AttModule(nn.Module):
    """One ASFormer block.

    Steps: dilated conv feed-forward, then instance-normalised attention added
    back with weight ``alpha``, then 1x1 conv and dropout, then a residual
    connection to the input. The result is masked by channel 0 of ``mask``.
    """

    def __init__(self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha):
        super().__init__()
        self.feed_forward = ConvFeedForward(dilation, in_channels, out_channels)
        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
        # The attention block/window length equals this layer's dilation.
        self.att_layer = AttLayer(in_channels, in_channels, out_channels, r1, r1, r2, dilation,
                                  att_type=att_type, stage=stage)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout()
        self.alpha = alpha

    def forward(self, x, f, mask):
        hidden = self.feed_forward(x)
        attended = self.att_layer(self.instance_norm(hidden), f, mask)
        hidden = self.alpha * attended + hidden
        residual = self.dropout(self.conv_1x1(hidden))
        return (x + residual) * mask[:, 0:1, :]
302
+
303
+
304
class PositionalEncoding(nn.Module):
    """Learnable sinusoidal positional encoding for (N, d_model, L) features.

    The table is initialised with the standard sin/cos pattern (sines on even
    channels, cosines on odd channels). It is stored as a trainable
    ``nn.Parameter`` ``pe`` of shape ``(1, d_model, max_len)``. Both even and
    odd ``d_model`` are supported.
    """

    def __init__(self, d_model, max_len=10000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        # Frequencies 1 / 10000^(2i / d_model), computed via exp(-log(.)).
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).permute(0, 2, 1)  # (1, d_model, max_len)
        self.pe = nn.Parameter(pe, requires_grad=True)

    def forward(self, x):
        """Add the first ``x.shape[2]`` positions of the table to ``x``."""
        return x + self.pe[:, :, 0:x.shape[2]]
322
+
323
class Encoder(nn.Module):
    """ASFormer encoder: optional channel masking, 1x1 projection, stacked
    encoder ``AttModule`` layers with dilations 1, 2, 4, ..., and a 1x1
    classifier head.
    """

    def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate, att_type, alpha):
        super().__init__()
        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # per-frame linear projection
        self.layers = nn.ModuleList([
            AttModule(2 ** depth, num_f_maps, num_f_maps, r1, r2, att_type, 'encoder', alpha)
            for depth in range(num_layers)
        ])
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.dropout = nn.Dropout2d(p=channel_masking_rate)
        self.channel_masking_rate = channel_masking_rate

    def forward(self, x, mask):
        '''
        :param x: input features, presumably (N, C, L) — it is unsqueezed at
            dim 2 for channel dropout.
        :param mask: (N, *, L) mask; only channel 0 is used.
        :return: tuple ``(logits, feature)``; logits are masked by channel 0.
        '''
        if self.channel_masking_rate > 0:
            # Dropout2d on (N, C, 1, L) drops entire input channels.
            x = self.dropout(x.unsqueeze(2)).squeeze(2)

        feature = self.conv_1x1(x)
        for layer in self.layers:
            feature = layer(feature, None, mask)

        logits = self.conv_out(feature) * mask[:, 0:1, :]
        return logits, feature
354
+
355
+
356
class Decoder(nn.Module):
    """ASFormer decoder: 1x1 projection, stacked decoder ``AttModule`` layers
    (values taken from the encoder features), and a 1x1 classifier head.
    """

    def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha):
        super().__init__()
        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
        self.layers = nn.ModuleList([
            AttModule(2 ** depth, num_f_maps, num_f_maps, r1, r2, att_type, 'decoder', alpha)
            for depth in range(num_layers)
        ])
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x, fencoder, mask):
        '''
        :param x: decoder input (the previous stage's class scores in MyTransformer).
        :param fencoder: feature map passed to every layer as the value source.
        :param mask: (N, *, L) mask; only channel 0 is used.
        :return: tuple ``(logits, feature)``; logits are masked by channel 0.
        '''
        feature = self.conv_1x1(x)
        for layer in self.layers:
            feature = layer(feature, fencoder, mask)

        logits = self.conv_out(feature) * mask[:, 0:1, :]
        return logits, feature
374
+
375
class MyTransformer(nn.Module):
    """ASFormer: one encoder followed by ``num_decoders`` refinement decoders.

    Every stage uses sliding-window attention. Decoder ``s`` uses
    ``alpha = exponential_descrease(s)``.
    """

    def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate):
        super().__init__()
        self.encoder = Encoder(num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate,
                               att_type='sliding_att', alpha=1)
        self.decoders = nn.ModuleList([
            Decoder(num_layers, r1, r2, num_f_maps, num_classes, num_classes,
                    att_type='sliding_att', alpha=exponential_descrease(stage_idx))
            for stage_idx in range(num_decoders)
        ])

    def forward(self, x, mask):
        '''
        :return: logits of every stage stacked on a new leading axis:
            (1 + num_decoders, N, num_classes, L).
        '''
        valid = mask[:, 0:1, :]
        logits, feature = self.encoder(x, mask)
        stage_logits = [logits]

        for decoder in self.decoders:
            # Each decoder refines the masked softmax of the previous stage.
            logits, feature = decoder(F.softmax(logits, dim=1) * valid, feature * valid, mask)
            stage_logits.append(logits)

        return torch.stack(stage_logits, dim=0)
@@ -0,0 +1,74 @@
1
+ """
2
+ From "A Context-Aware Loss Function for Action Spotting in Soccer Videos"
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
class ContextAwareWeights:
    """Weight table for the context-aware (CALF) loss.

    Builds an array of shape ``(k1 + k2 + k3 + k4, 3)``. Its columns are
    (multiplicative weight, additive weight, radius), laid out in four
    consecutive segments of k1, k2, k3 and k4 rows. The segments presumably
    correspond to temporal offsets around an action; confirm against the
    caller.
    """

    def __init__(self, k1=2, k2=1, k3=2, k4=2, hit_radius=0.1, miss_radius=0.9):
        total = k1 + k2 + k3 + k4
        mul_w = np.ones(total)
        add_w = np.ones(total)
        radius = np.full(total, miss_radius)

        # Segment 1: multiplicative weight ramps down from 1 towards 1/k1.
        for j in range(k1):
            mul_w[j] = (k1 - j) / k1
        start = k1

        # Segment 2: multiplicative weight is zero.
        for j in range(k2):
            mul_w[start + j] = 0
        start += k2

        # Segment 3: negative multiplicative ramp, additive ramp from 0,
        # and the "hit" radius.
        for j in range(k3):
            mul_w[start + j] = (j - k3) / k3
            add_w[start + j] = j / k3
            radius[start + j] = 1. - hit_radius
        start += k3

        # Segment 4: multiplicative weight ramps up from 0.
        for j in range(k4):
            mul_w[start + j] = j / k4

        self._w = np.stack([mul_w, add_w, radius], axis=1)
        self._offset = k1 + k2

    @property
    def weights(self):
        """The ``(n, 3)`` weight table."""
        return self._w

    @property
    def offset(self):
        """Index of the first row after segments 1 and 2 (``k1 + k2``)."""
        return self._offset

    def __len__(self):
        return self._w.shape[0]
41
+
42
+
43
# Module-level switch used by get_calf(): while truthy, get_calf() checks the
# summed loss for Inf/NaN and clears this flag after the first report, so the
# message is printed once. set_calf_error_flag() turns the check back on.
CALF_ERROR_FLAG = True
44
+
45
+
46
def set_calf_error_flag():
    """Re-enable the one-shot Inf/NaN check performed by ``get_calf``.

    Sets the module-level ``CALF_ERROR_FLAG`` to the bool ``True``, matching
    its initial value and the ``False`` that ``get_calf`` assigns.
    """
    global CALF_ERROR_FLAG
    CALF_ERROR_FLAG = True
49
+
50
+
51
def get_calf(pred, weights):
    """Compute the context-aware loss (CALF), averaged over all elements.

    Args:
        pred: logits; softmax is taken over dim 2 (labelled (N, L, C) in the
            original code). Class 0 is excluded from the loss.
        weights: 4-D tensor whose last axis holds (multiplicative weight,
            additive weight, radius); it must broadcast against
            ``pred[:, :, 1:]``. Presumably built from ContextAwareWeights.

    Returns:
        Scalar mean of ``max(0, -log(add - p * mul) + log(radius))``.

    While the module flag ``CALF_ERROR_FLAG`` is truthy, the summed loss is
    checked for Inf and NaN. On a hit, a message is printed and the flag is
    cleared.
    """
    global CALF_ERROR_FLAG

    scores = F.softmax(pred, dim=2)  # (N, L, C)
    mul_w = weights[:, :, :, 0]
    add_w = weights[:, :, :, 1]
    radius = weights[:, :, :, 2]
    loss = -torch.log(add_w - scores[:, :, 1:] * mul_w) + torch.log(radius)
    loss = torch.max(torch.zeros_like(loss), loss)

    if CALF_ERROR_FLAG:
        total = torch.sum(loss)
        checks = (
            (torch.isinf, 'Found Inf in CALF. Supressing future errors.'),
            (torch.isnan, 'Found NaN in CALF. Supressing future errors.'),
        )
        for is_bad, message in checks:
            if is_bad(total):
                print(message)
                CALF_ERROR_FLAG = False
    return torch.mean(loss)
68
+
69
+
70
if __name__ == '__main__':
    # Manual check: print the weight table, then -(add - p * mul) for constant
    # predictions p = 1 and p = 0.
    demo = ContextAwareWeights()
    table = demo.weights
    size = len(demo)
    print(table)
    print('All 1:', -(table[:, 1] - np.ones(size) * table[:, 0]))
    print('All 0:', -(table[:, 1] - np.zeros(size) * table[:, 0]))
@@ -0,0 +1,112 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+ # Adapted PyTorch GSM implementation by Swathikiran Sudhakaran, https://github.com/swathikirans/GSM
31
+ # Original license for GSM
32
+ """
33
+ BSD 2-Clause License for GSM
34
+
35
+ Copyright (c) 2019, FBK
36
+ All rights reserved.
37
+
38
+ Redistribution and use in source and binary forms, with or without
39
+ modification, are permitted provided that the following conditions are met:
40
+
41
+ * Redistributions of source code must retain the above copyright notice, this
42
+ list of conditions and the following disclaimer.
43
+
44
+ * Redistributions in binary form must reproduce the above copyright notice,
45
+ this list of conditions and the following disclaimer in the documentation
46
+ and/or other materials provided with the distribution.
47
+
48
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
49
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
50
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
51
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
52
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
53
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
54
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
55
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
56
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
57
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
58
+ """
59
+
60
+ import torch
61
+ from torch import nn
62
+ from torch.cuda import FloatTensor as ftens
63
+
64
+
65
+ class _GSM(nn.Module):
66
+ def __init__(self, fPlane, num_segments=3):
67
+ super(_GSM, self).__init__()
68
+
69
+ self.conv3D = nn.Conv3d(fPlane, 2, (3, 3, 3), stride=1,
70
+ padding=(1, 1, 1), groups=2)
71
+ nn.init.constant_(self.conv3D.weight, 0)
72
+ nn.init.constant_(self.conv3D.bias, 0)
73
+ self.tanh = nn.Tanh()
74
+ self.fPlane = fPlane
75
+ self.num_segments = num_segments
76
+ self.bn = nn.BatchNorm3d(num_features=fPlane)
77
+ self.relu = nn.ReLU()
78
+
79
+ def lshift_zeroPad(self, x):
80
+ return torch.cat((x[:,:,1:], ftens(x.size(0), x.size(1), 1, x.size(3), x.size(4)).fill_(0)), dim=2)
81
+
82
+ def rshift_zeroPad(self, x):
83
+ return torch.cat((ftens(x.size(0), x.size(1), 1, x.size(3), x.size(4)).fill_(0), x[:,:,:-1]), dim=2)
84
+
85
+ def forward(self, x):
86
+ batchSize = x.size(0) // self.num_segments
87
+ shape = x.size(1), x.size(2), x.size(3)
88
+ assert shape[0] == self.fPlane
89
+ x = x.view(batchSize, self.num_segments, *shape).permute(0, 2, 1, 3, 4).contiguous()
90
+ x_bn = self.bn(x)
91
+ x_bn_relu = self.relu(x_bn)
92
+ gate = self.tanh(self.conv3D(x_bn_relu))
93
+ gate_group1 = gate[:, 0].unsqueeze(1)
94
+ gate_group2 = gate[:, 1].unsqueeze(1)
95
+ x_group1 = x[:, :self.fPlane // 2]
96
+ x_group2 = x[:, self.fPlane // 2:]
97
+ y_group1 = gate_group1 * x_group1
98
+ y_group2 = gate_group2 * x_group2
99
+
100
+ r_group1 = x_group1 - y_group1
101
+ r_group2 = x_group2 - y_group2
102
+
103
+ y_group1 = self.lshift_zeroPad(y_group1) + r_group1
104
+ y_group2 = self.rshift_zeroPad(y_group2) + r_group2
105
+
106
+ y_group1 = y_group1.view(batchSize, 2, self.fPlane // 4, self.num_segments, *shape[1:]).permute(0, 2, 1, 3, 4, 5)
107
+ y_group2 = y_group2.view(batchSize, 2, self.fPlane // 4, self.num_segments, *shape[1:]).permute(0, 2, 1, 3, 4, 5)
108
+
109
+ y = torch.cat((y_group1.contiguous().view(batchSize, self.fPlane//2, self.num_segments, *shape[1:]),
110
+ y_group2.contiguous().view(batchSize, self.fPlane//2, self.num_segments, *shape[1:])), dim=1)
111
+
112
+ return y.permute(0, 2, 1, 3, 4).contiguous().view(batchSize*self.num_segments, *shape)