linmult 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
linmult-1.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 fodorad
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
linmult-1.1.0/PKG-INFO ADDED
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.1
2
+ Name: linmult
3
+ Version: 1.1.0
4
+ Summary: General-purpose Multimodal Transformer with Linear Complexity Attention Mechanism
5
+ Author: fodorad
6
+ License: MIT
7
+ License-File: LICENSE
@@ -0,0 +1,143 @@
1
+ # LinMulT
2
+
3
+ General-purpose Multimodal Transformer with Linear Complexity Attention Mechanism.
4
+
5
+ # Setup
6
+
7
+ ### Environment
8
+ * Python 3.10+
9
+ * PyTorch 1.13.1+cu117 (with cuDNN)
10
+
11
+ ### Install package with pip+git
12
+ ```
13
+ pip install -U git+https://github.com/fodorad/LinMulT.git
14
+ ```
15
+
16
+ ### Install package from repository root
17
+ ```
18
+ git clone https://github.com/fodorad/LinMulT
19
+ cd LinMulT
20
+ pip install -e .
21
+ pip install -U -r requirements.txt
22
+ ```
23
+
24
+ # Quick start
25
+ ### Example 1:
26
+ Simple transformer encoder with linear attention.
27
+ The forward pass is performed using an input sequence.
28
+ ```
29
+ import torch
30
+ from linmult import LinT
31
+
32
+ # input shape: (batch_size, time_dimension, feature_dimension)
33
+ x = torch.rand((32, 15, 1024), device='cuda')
34
+ model = LinT(input_modality_channels=1024, output_dim=5).cuda()
35
+ y_pred_seq = model(x)
36
+
37
+ # output shape: (batch_size, time_dimension, output_dimension)
38
+ assert y_pred_seq.size() == torch.Size([32, 15, 5])
39
+ ```
40
+
41
+ ### Example 2:
42
+ Multimodal Transformer with Linear Attention.
43
+ The forward pass is performed using 2 input sequences. Both input sequences have the same time dimension.
44
+ ```
45
+ import torch
46
+ from linmult import LinMulT
47
+
48
+ # input shape: (batch_size, time_dimension, feature_dimension)
49
+ x_1 = torch.rand((32, 15, 1024), device='cuda')
50
+ x_2 = torch.rand((32, 15, 160), device='cuda')
51
+ model = LinMulT(input_modality_channels=[1024, 160], output_dim=5).cuda()
52
+ y_pred_cls, y_pred_seq = model([x_1, x_2])
53
+
54
+ # 1. output shape: (batch_size, output_dimension)
55
+ assert y_pred_cls.size() == torch.Size([32, 5])
56
+
57
+ # 2. output shape: (batch_size, time_dimension, output_dimension)
58
+ assert y_pred_seq.size() == torch.Size([32, 15, 5])
59
+ ```
60
+
61
+ ### Example 3:
62
+ Multimodal Transformer with Linear Attention. The forward pass is performed using 3 input sequences with different time dimensions.
63
+ ```
64
+ import torch
65
+ from linmult import LinMulT
66
+
67
+ # input shape: (batch_size, time_dimension, feature_dimension)
68
+ x_1 = torch.rand((16, 1500, 25), device='cuda')
69
+ x_2 = torch.rand((16, 450, 35), device='cuda')
70
+ x_3 = torch.rand((16, 120, 768), device='cuda')
71
+ model = LinMulT(input_modality_channels=[25, 35, 768],
72
+ output_dim=5,
73
+ add_time_collapse=True,
74
+ add_self_attention_fusion=False).cuda()
75
+ y_pred_cls = model([x_1, x_2, x_3])
76
+
77
+ # output shape: (batch_size, output_dimension)
78
+ assert y_pred_cls.size() == torch.Size([16, 5])
79
+ ```
80
+
81
+ # Run tests
82
+ ```
83
+ python -m unittest
84
+ ```
85
+ # Similar projects using LinMulT
86
+
87
+ ### (2023) BlinkLinMulT
88
+ LinMulT is trained for blink presence detection and eye state recognition tasks.
89
+ Our results demonstrate comparable or superior performance compared to state-of-the-art models on 2 tasks, using 7 public benchmark databases.
90
+ * paper: BlinkLinMulT: Transformer-based Eye Blink Detection (accepted, available soon)
91
+ * code: https://github.com/fodorad/BlinkLinMulT
92
+
93
+ ### (2022) PersonalityLinMulT
94
+ LinMulT is trained for Big Five personality trait estimation using the First Impressions V2 dataset and sentiment estimation using the MOSI and MOSEI datasets.
95
+ * paper: Multimodal Sentiment and Personality Perception Under Speech: A Comparison of Transformer-based Architectures ([pdf](https://proceedings.mlr.press/v173/fodor22a/fodor22a.pdf), [website](https://proceedings.mlr.press/v173/fodor22a.html))
96
+ * code: https://github.com/fodorad/PersonalityLinMulT
97
+
98
+
99
+ # Citation - BibTex
100
+ If you found our research helpful or influential please consider citing:
101
+
102
+ ### (2023) LinMulT for blink presence detection and eye state recognition:
103
+ ```
104
+ @article{blinklinmult-fodor23,
105
+ title = {BlinkLinMulT: Transformer-based Eye Blink Detection},
106
+ author = {Fodor, {\'A}d{\'a}m and Fenech, Kristian and L{\H{o}}rincz, Andr{\'a}s},
107
+ journal = {...},
108
+ pages = {1--19},
109
+ year = {2023}
110
+ }
111
+ ```
112
+
113
+ ### (2022) LinMulT for personality trait and sentiment estimation:
114
+ ```
115
+ @InProceedings{pmlr-v173-fodor22a,
116
+ title = {Multimodal Sentiment and Personality Perception Under Speech: A Comparison of Transformer-based Architectures},
117
+ author = {Fodor, {\'A}d{\'a}m and Saboundji, Rachid R. and Jacques Junior, Julio C. S. and Escalera, Sergio and Gallardo-Pujol, David and L{\H{o}}rincz, Andr{\'a}s},
118
+ booktitle = {Understanding Social Behavior in Dyadic and Small Group Interactions},
119
+ pages = {218--241},
120
+ year = {2022},
121
+ editor = {Palmero, Cristina and Jacques Junior, Julio C. S. and Clapés, Albert and Guyon, Isabelle and Tu, Wei-Wei and Moeslund, Thomas B. and Escalera, Sergio},
122
+ volume = {173},
123
+ series = {Proceedings of Machine Learning Research},
124
+ month = {16 Oct},
125
+ publisher = {PMLR},
126
+ pdf = {https://proceedings.mlr.press/v173/fodor22a/fodor22a.pdf},
127
+ url = {https://proceedings.mlr.press/v173/fodor22a.html}
128
+ }
129
+ ```
130
+
131
+ # Acknowledgement
132
+ The code is inspired by the following two materials:
133
+
134
+ ### Multimodal Transformer:
135
+ * paper: Multimodal Transformer for Unaligned Multimodal Language Sequences ([1906.00295](https://arxiv.org/pdf/1906.00295.pdf))
136
+ * code: https://github.com/yaohungt/Multimodal-Transformer
137
+
138
+ ### Linear Attention:
139
+ * paper: Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention ([2006.16236](https://arxiv.org/pdf/2006.16236.pdf))
140
+ * code: https://github.com/idiap/fast-transformers
141
+
142
+ # Contact
143
+ * Ádám Fodor (foauaai@inf.elte.hu)
@@ -0,0 +1,2 @@
1
+ from linmult.models.LinT import LinT
2
+ from linmult.models.LinMulT import LinMulT
@@ -0,0 +1,238 @@
1
+ ##########################################################
2
+ # #
3
+ # Code is inspired by the following repositories: #
4
+ # https://github.com/yaohungt/Multimodal-Transformer #
5
+ # #
6
+ ##########################################################
7
+ import logging
8
+ from typing import Iterable
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+ from linmult.models.transformer import TransformerEncoder
13
+
14
+ logging.basicConfig(level=logging.INFO,
15
+ format="%(asctime)s %(levelname)s %(message)s",
16
+ datefmt="%Y-%m-%d %H:%M:%S")
17
+
18
+
19
+ class LinMulT(nn.Module):
20
+
21
+ def __init__(self,
22
+ input_modality_channels: Iterable[int],
23
+ output_dim: int,
24
+ projected_modality_dim: int | list = 40, # d
25
+ number_of_heads: int = 8,
26
+ number_of_layers: int = 4, # D
27
+ embedding_dropout: float = 0.1,
28
+ cross_attention_dropout: float = 0.1,
29
+ self_attention_dropout: float = 0.0,
30
+ relu_dropout: float = 0.1,
31
+ residual_dropout: float = 0.1,
32
+ output_dropout: float = 0.1,
33
+ attention_mask: bool = True,
34
+ add_time_collapse: bool = False,
35
+ add_self_attention_fusion: bool = True,
36
+ add_projection_fusion: bool = True,
37
+ aggregation: str = 'meanpooling'):
38
+ """Construct a MulT model with linear attention mechanism.
39
+
40
+ Args:
41
+ aggregation (str | None): aggregation applied to the output sequence to get output_cls.
42
+ None - when add_time_collapse is True, aggregation is not used at all.
43
+ last - last timestep is used. Original MulT implementation.
44
+ cls - classification token is used.
45
+ meanpooling - mean is calculated over the T time dimension.
46
+ maxpooling - max is calculated over the T time dimension.
47
+ """
48
+ super().__init__()
49
+
50
+ if aggregation not in {None, 'last', 'cls', 'meanpooling', 'maxpooling'}:
51
+ raise Exception(f'Invalid aggregation {aggregation}.')
52
+
53
+ if add_time_collapse and add_self_attention_fusion:
54
+ raise Exception(f'These arguments cannot be True at the same time: {{add_time_collapse, add_self_attention_fusion}}')
55
+
56
+ self.input_modality_channels = input_modality_channels
57
+ self.output_dim = output_dim
58
+ self.number_of_modalities = len(self.input_modality_channels)
59
+
60
+ if isinstance(projected_modality_dim, int):
61
+ self.projected_modality_dim = [projected_modality_dim] * self.number_of_modalities
62
+ else: # list
63
+ if len(projected_modality_dim) != self.number_of_modalities:
64
+ raise Exception('Length of projected_modality_dim should be the number of modalities.')
65
+ self.projected_modality_dim = projected_modality_dim
66
+
67
+ self.number_of_heads = number_of_heads
68
+ self.number_of_layers = number_of_layers
69
+ self.embedding_dropout = embedding_dropout
70
+ self.cross_attention_dropout = cross_attention_dropout
71
+ self.self_attention_dropout = self_attention_dropout
72
+ self.relu_dropout = relu_dropout
73
+ self.residual_dropout = residual_dropout
74
+ self.output_dropout = output_dropout
75
+ self.attention_mask = attention_mask
76
+ self.add_time_collapse = add_time_collapse
77
+ self.add_self_attention_fusion = add_self_attention_fusion
78
+ self.add_projection_fusion = add_projection_fusion
79
+ self.aggregation = aggregation if not add_time_collapse else None
80
+ combined_dim = (self.number_of_modalities - 1) * torch.tensor(self.projected_modality_dim).sum()
81
+
82
+ # 1. Temporal Convolutional Layers
83
+ self.projectors = nn.ModuleList([
84
+ nn.Conv1d(input_modality_channels, projected_modality_dim, kernel_size=1, padding=0, bias=False)
85
+ for input_modality_channels, projected_modality_dim
86
+ in zip(self.input_modality_channels, self.projected_modality_dim)
87
+ ])
88
+
89
+ # 2. Crossmodal Attention Transformers
90
+ # e.g.: a, v, t modalities correspond to 0, 1, 2 indices
91
+ # Q -> a, K and V -> v, t: v t - 1 2
92
+ # Q -> v, K and V -> a, t: a t - 0 2
93
+ # Q -> t, K and V -> a, v: a v - 0 1
94
+ self.modality_indices = range(self.number_of_modalities)
95
+ self.crossmodal_transformers = nn.ModuleList([])
96
+ for target_index in self.modality_indices: # e.g. target_index = 0
97
+ input_indices = [ind for ind in self.modality_indices if ind != target_index] # e.g. input_indices = [1, 2]
98
+ self.crossmodal_transformers.append(
99
+ nn.ModuleList([
100
+ self.create_transformer(modality_index=input_index, attention_type='cross')
101
+ for input_index in input_indices
102
+ ])
103
+ )
104
+
105
+ # 3. Self Attention Transformers
106
+ self.self_attention_transformers = nn.ModuleList([
107
+ self.create_transformer(modality_index=target_index, attention_type='self', layers=3)
108
+ for target_index in self.modality_indices
109
+ ])
110
+
111
+ # 4. Self Attention Fusion Transformer
112
+ if self.add_self_attention_fusion:
113
+ self.self_attention_fusion_transformer = self.create_fusion_transformer()
114
+
115
+ if self.add_projection_fusion:
116
+ self.projection_1 = nn.Linear(combined_dim, combined_dim)
117
+ self.projection_2 = nn.Linear(combined_dim, combined_dim)
118
+
119
+ # 5. Sequence Head & Aggregation
120
+ self.out_layer = nn.Linear(combined_dim, self.output_dim) # (B, T, output_dim) or (B, output_dim)
121
+
122
+ def create_transformer(self, modality_index, attention_type: str, layers=-1):
123
+ if attention_type == 'cross': # Crossmodal Attention Transformer
124
+ embedding_dim = self.projected_modality_dim[modality_index]
125
+ attention_dropout = self.cross_attention_dropout
126
+ else: # Self Attention Transformer
127
+ embedding_dim = (self.number_of_modalities - 1) * self.projected_modality_dim[modality_index]
128
+ attention_dropout = self.self_attention_dropout
129
+
130
+ return TransformerEncoder(embedding_dim=embedding_dim,
131
+ number_of_heads=self.number_of_heads,
132
+ number_of_layers=max(self.number_of_layers, layers),
133
+ attention_dropout=attention_dropout,
134
+ relu_dropout=self.relu_dropout,
135
+ residual_dropout=self.residual_dropout,
136
+ embedding_dropout=self.embedding_dropout,
137
+ attention_mask=self.attention_mask)
138
+
139
+ def create_fusion_transformer(self, layers=-1):
140
+ return TransformerEncoder(embedding_dim=self.number_of_modalities * self.projected_modality_dim[0],
141
+ number_of_heads=self.number_of_heads,
142
+ number_of_layers=max(self.number_of_layers, layers),
143
+ attention_dropout=self.self_attention_dropout,
144
+ relu_dropout=self.relu_dropout,
145
+ residual_dropout=self.residual_dropout,
146
+ embedding_dropout=self.self_attention_dropout,
147
+ attention_mask=self.attention_mask)
148
+
149
+ def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
150
+ """Inference with Multimodal Transformer.
151
+
152
+ Args:
153
+ inputs (list[torch.Tensor]): input tensors of shape (B, T, F)
154
+
155
+ Returns:
156
+ (torch.Tensor | tuple[torch.Tensor, torch.Tensor]): tensor of shape (B, F) and/or (B, T, F)
157
+ """
158
+ # transpose and add embedding dropout
159
+ inp = [] # x_a, x_v, x_t
160
+ for input in inputs:
161
+ input_T = input.transpose(1, 2) # (B, T, F) -> (B, F, T)
162
+ if self.embedding_dropout > 0:
163
+ inp.append(F.dropout(input_T, p=self.embedding_dropout, training=self.training))
164
+ else:
165
+ inp.append(input_T)
166
+ logging.debug(f'input sizes: {[tuple(i.size()) for i in inp]}')
167
+
168
+ # temporal convolution projection of input tensors
169
+ proj_x_mod = [self.projectors[i](input).permute(0, 2, 1) for i, input in enumerate(inp)]
170
+ logging.debug(f'projected input sizes: {[tuple(i.size()) for i in proj_x_mod]}')
171
+
172
+ if self.aggregation == 'cls':
173
+ # add cls token to every input as the first timestamp
174
+ # (projected_dim,) -> (1, 1, projected_dim) -> (batch_size, 1, projected_dim)
175
+ cls_tokens = [
176
+ torch.zeros((proj_x_mod[i].shape[0], 1, proj_x_mod[i].shape[-1]), device=proj_x_mod[i].device)
177
+ for _ in range(self.number_of_modalities)
178
+ ]
179
+
180
+ proj_x_mod = [
181
+ torch.cat((cls_token, projected_representation), dim=1)
182
+ for projected_representation, cls_token in zip(proj_x_mod, cls_tokens)
183
+ ] # (B, T, F) -> (B, T+1, F)
184
+
185
+ # cross-modal transformers
186
+ hidden_representations = []
187
+ for target_index in range(self.number_of_modalities): # e.g. target_index == 0
188
+ input_indices = [ind for ind in self.modality_indices if ind != target_index] # e.g. input_indices = [1, 2]
189
+ cross_modal_hidden = []
190
+ for i, input_index in enumerate(input_indices):
191
+ # AVT: (V,T) --> A
192
+ logging.debug(f"Query: {[f'modality_{m}' for m in self.modality_indices][target_index]} with shape {tuple(proj_x_mod[target_index].size())} " + \
193
+ f"--> Keys, Values: {[f'modality_{m}' for m in self.modality_indices][input_index]} with shape {tuple(proj_x_mod[input_index].size())}")
194
+ cross_modal_hidden.append(
195
+ self.crossmodal_transformers[target_index][i](
196
+ proj_x_mod[target_index], proj_x_mod[input_index], proj_x_mod[input_index])
197
+ ) # Q, K, V
198
+ logging.debug(f"num of crossmodal transformers: {len(cross_modal_hidden)}, tensor shapes: {[tuple(elem.size()) for elem in cross_modal_hidden]}")
199
+
200
+ # self-attention transformer
201
+ cross_modal_hidden = torch.cat(cross_modal_hidden, dim=2) # within branch
202
+ self_hidden = self.self_attention_transformers[target_index](cross_modal_hidden)
203
+ hidden_representations.append(self_hidden) # (B, T, F) or (B, T+1, F)
204
+ logging.debug(f"last hidden representations with shapes: {[tuple(elem.size()) for elem in hidden_representations]}")
205
+
206
+ if self.add_time_collapse:
207
+ hidden_representation = torch.cat([hidden_representation[:,-1,:] for hidden_representation in hidden_representations], dim=-1) # [(B, T, F), ...] -> (B, combined_dim)
208
+ else:
209
+ hidden_representation = torch.cat(hidden_representations, dim=-1) # [(B, T, F), ...] -> (B, T, combined_dim)
210
+
211
+ if self.add_self_attention_fusion:
212
+ hidden_representation = self.self_attention_fusion_transformer(hidden_representation)
213
+
214
+ if self.add_projection_fusion:
215
+ hidden_representation = self.projection_2(F.dropout(F.relu(self.projection_1(hidden_representation)), p=self.output_dropout, training=self.training)) \
216
+ + hidden_representation # (B, T, combined_dim) or (B, combined_dim)
217
+
218
+ if self.add_time_collapse:
219
+ output_cls = self.out_layer(hidden_representation)
220
+ return output_cls
221
+ else:
222
+ match self.aggregation:
223
+ case 'last':
224
+ output_cls = self.out_layer(hidden_representation[:, -1, :]) # (B, combined_dim)
225
+ case 'cls':
226
+ output_cls = self.out_layer(hidden_representation[:, 0, :]) # (B, T+1, combined_dim) -> (B, combined_dim)
227
+ hidden_representation = hidden_representation[:, 1:, :] # (B, T+1, combined_dim) -> (B, T, combined_dim)
228
+ case 'maxpooling':
229
+ output_cls = self.out_layer(torch.max(hidden_representation, dim=1)) # (B, T, combined_dim) -> (B, combined_dim)
230
+ case _: # 'meanpooling'
231
+ output_cls = self.out_layer(torch.mean(hidden_representation, dim=1)) # (B, T, combined_dim) -> (B, combined_dim)
232
+
233
+ # output_cls head: sequence -> aggregation -> dense -> summarized logits
234
+ # output_seq head: sequence -> time-distributed dense -> sequence-wise logits
235
+ output_seq = self.out_layer(hidden_representation)
236
+ logging.debug(f"output output_cls shape: {tuple(output_cls.size())}")
237
+ logging.debug(f"output output_seq shape: {tuple(output_seq.size())}")
238
+ return output_cls, output_seq
@@ -0,0 +1,75 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from linmult.models.transformer import TransformerEncoder
5
+
6
+
7
class LinT(nn.Module):
    """Unimodal transformer encoder with linear attention.

    Pipeline: pointwise temporal convolution projects the input features to
    the working dimension, a self-attention transformer encodes the sequence,
    and a time-distributed dense layer produces the output logits.
    """

    def __init__(self,
                 input_modality_channels: int,
                 output_dim: int,
                 projected_modality_dim: int | list = 40,  # d
                 number_of_heads: int = 8,
                 number_of_layers: int = 4,  # D
                 embedding_dropout: float = 0.1,
                 cross_attention_dropout: float = 0.1,
                 self_attention_dropout: float = 0.0,
                 relu_dropout: float = 0.1,
                 residual_dropout: float = 0.1,
                 output_dropout: float = 0.1,
                 attention_mask: bool = True):
        super().__init__()

        # Keep every hyperparameter around as an attribute.
        self.input_modality_channels = input_modality_channels
        self.output_dim = output_dim
        self.projected_modality_dim = projected_modality_dim
        self.number_of_heads = number_of_heads
        self.number_of_layers = number_of_layers
        self.embedding_dropout = embedding_dropout
        self.cross_attention_dropout = cross_attention_dropout
        self.self_attention_dropout = self_attention_dropout
        self.relu_dropout = relu_dropout
        self.residual_dropout = residual_dropout
        self.output_dropout = output_dropout
        self.attention_mask = attention_mask

        # 1. Temporal convolution: pointwise kernel, (B, F, T) -> (B, d, T)
        self.projector = nn.Conv1d(input_modality_channels, projected_modality_dim,
                                   kernel_size=1, padding=0, bias=False)

        # 2. Self-attention transformer encoder with linear attention
        self.self_attention_transformer = TransformerEncoder(
            embedding_dim=self.projected_modality_dim,
            number_of_heads=self.number_of_heads,
            number_of_layers=self.number_of_layers,
            attention_dropout=self.self_attention_dropout,
            relu_dropout=self.relu_dropout,
            residual_dropout=self.residual_dropout,
            embedding_dropout=self.self_attention_dropout,
            attention_mask=self.attention_mask)

        # 3. Time-distributed output projection
        self.out_layer = nn.Linear(self.projected_modality_dim, self.output_dim)

    def forward(self, input: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
        """input tensor of shape (B, T, F); returns logits of shape (B, T, output_dim)."""
        # Accept a singleton list for API parity with LinMulT.
        if isinstance(input, list):
            if len(input) != 1:
                raise Exception(f'A single tensor is expected got instead {len(input)}.')
            input = input[0]

        features_first = input.transpose(1, 2)  # (B, T, F) -> (B, F, T)

        if self.embedding_dropout > 0:
            features_first = F.dropout(features_first, p=self.embedding_dropout, training=self.training)

        projected = self.projector(features_first).permute(0, 2, 1)  # (B, T, d)
        encoded = self.self_attention_transformer(projected)
        return self.out_layer(encoded)
File without changes