linmult 1.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- linmult-1.1.0/LICENSE +21 -0
- linmult-1.1.0/PKG-INFO +7 -0
- linmult-1.1.0/README.md +143 -0
- linmult-1.1.0/linmult/__init__.py +2 -0
- linmult-1.1.0/linmult/models/LinMulT.py +238 -0
- linmult-1.1.0/linmult/models/LinT.py +75 -0
- linmult-1.1.0/linmult/models/__init__.py +0 -0
- linmult-1.1.0/linmult/models/linear_attention.py +217 -0
- linmult-1.1.0/linmult/models/masking.py +207 -0
- linmult-1.1.0/linmult/models/position_embedding.py +104 -0
- linmult-1.1.0/linmult/models/transformer.py +193 -0
- linmult-1.1.0/linmult/tests/__init__.py +0 -0
- linmult-1.1.0/linmult/tests/example_inference_cls_config.py +13 -0
- linmult-1.1.0/linmult/tests/example_inference_seq_config.py +12 -0
- linmult-1.1.0/linmult/tests/example_train_cls_config.py +23 -0
- linmult-1.1.0/linmult/tests/example_train_seq_config.py +25 -0
- linmult-1.1.0/linmult/tests/test_architecture.py +41 -0
- linmult-1.1.0/linmult.egg-info/PKG-INFO +7 -0
- linmult-1.1.0/linmult.egg-info/SOURCES.txt +21 -0
- linmult-1.1.0/linmult.egg-info/dependency_links.txt +1 -0
- linmult-1.1.0/linmult.egg-info/top_level.txt +1 -0
- linmult-1.1.0/setup.cfg +4 -0
- linmult-1.1.0/setup.py +10 -0
linmult-1.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2022 fodorad
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
linmult-1.1.0/PKG-INFO
ADDED
linmult-1.1.0/README.md
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# LinMulT
|
|
2
|
+
|
|
3
|
+
General-purpose Multimodal Transformer with Linear Complexity Attention Mechanism.
|
|
4
|
+
|
|
5
|
+
# Setup
|
|
6
|
+
|
|
7
|
+
### Environment
|
|
8
|
+
* Python 3.10+
|
|
9
|
+
* PyTorch and cuDNN 1.13.1+cu117
|
|
10
|
+
|
|
11
|
+
### Install package with pip+git
|
|
12
|
+
```
|
|
13
|
+
pip install -U git+https://github.com/fodorad/LinMulT.git
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
### Install package from repository root
|
|
17
|
+
```
|
|
18
|
+
git clone https://github.com/fodorad/LinMulT
|
|
19
|
+
cd LinMulT
|
|
20
|
+
pip install -e .
|
|
21
|
+
pip install -U -r requirements.txt
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
# Quick start
|
|
25
|
+
### Example 1:
|
|
26
|
+
Simple transformer encoder with linear attention.
|
|
27
|
+
The forward pass is performed using an input sequence.
|
|
28
|
+
```
|
|
29
|
+
import torch
|
|
30
|
+
from linmult import LinT
|
|
31
|
+
|
|
32
|
+
# input shape: (batch_size, time_dimension, feature_dimension)
|
|
33
|
+
x = torch.rand((32, 15, 1024), device='cuda')
|
|
34
|
+
model = LinT(input_modality_channels=1024, output_dim=5).cuda()
|
|
35
|
+
y_pred_seq = model(x)
|
|
36
|
+
|
|
37
|
+
# output shape: (batch_size, time_dimension, output_dimension)
|
|
38
|
+
assert y_pred_seq.size() == torch.Size([32, 15, 5])
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
### Example 2:
|
|
42
|
+
Multimodal Transformer with Linear Attention.
|
|
43
|
+
The forward pass is performed using 2 input sequences. Both input sequences have the same time dimension.
|
|
44
|
+
```
|
|
45
|
+
import torch
|
|
46
|
+
from linmult import LinMulT
|
|
47
|
+
|
|
48
|
+
# input shape: (batch_size, time_dimension, feature_dimension)
|
|
49
|
+
x_1 = torch.rand((32, 15, 1024), device='cuda')
|
|
50
|
+
x_2 = torch.rand((32, 15, 160), device='cuda')
|
|
51
|
+
model = LinMulT(input_modality_channels=[1024, 160], output_dim=5).cuda()
|
|
52
|
+
y_pred_cls, y_pred_seq = model([x_1, x_2])
|
|
53
|
+
|
|
54
|
+
# 1. output shape: (batch_size, output_dimension)
|
|
55
|
+
assert y_pred_cls.size() == torch.Size([32, 5])
|
|
56
|
+
|
|
57
|
+
# 2. output shape: (batch_size, time_dimension, output_dimension)
|
|
58
|
+
assert y_pred_seq.size() == torch.Size([32, 15, 5])
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
### Example 3:
|
|
62
|
+
Multimodal Transformer with Linear Attention. The forward pass is performed using 3 input sequences with different time dimensions.
|
|
63
|
+
```
|
|
64
|
+
import torch
|
|
65
|
+
from linmult import LinMulT
|
|
66
|
+
|
|
67
|
+
# input shape: (batch_size, time_dimension, feature_dimension)
|
|
68
|
+
x_1 = torch.rand((16, 1500, 25), device='cuda')
|
|
69
|
+
x_2 = torch.rand((16, 450, 35), device='cuda')
|
|
70
|
+
x_3 = torch.rand((16, 120, 768), device='cuda')
|
|
71
|
+
model = LinMulT(input_modality_channels=[25, 35, 768],
|
|
72
|
+
output_dim=5,
|
|
73
|
+
add_time_collapse=True,
|
|
74
|
+
add_self_attention_fusion=False).cuda()
|
|
75
|
+
y_pred_cls = model([x_1, x_2, x_3])
|
|
76
|
+
|
|
77
|
+
# output shape: (batch_size, output_dimension)
|
|
78
|
+
assert y_pred_cls.size() == torch.Size([16, 5])
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
# Run tests
|
|
82
|
+
```
|
|
83
|
+
python -m unittest
|
|
84
|
+
```
|
|
85
|
+
# Similar projects using LinMulT
|
|
86
|
+
|
|
87
|
+
### (2023) BlinkLinMulT
|
|
88
|
+
LinMulT is trained for blink presence detection and eye state recognition tasks.
|
|
89
|
+
Our results demonstrate comparable or superior performance compared to state-of-the-art models on 2 tasks, using 7 public benchmark databases.
|
|
90
|
+
* paper: BlinkLinMulT: Transformer-based Eye Blink Detection (accepted, available soon)
|
|
91
|
+
* code: https://github.com/fodorad/BlinkLinMulT
|
|
92
|
+
|
|
93
|
+
### (2022) PersonalityLinMulT
|
|
94
|
+
LinMulT is trained for Big Five personality trait estimation using the First Impressions V2 dataset and sentiment estimation using the MOSI and MOSEI datasets.
|
|
95
|
+
* paper: Multimodal Sentiment and Personality Perception Under Speech: A Comparison of Transformer-based Architectures ([pdf](https://proceedings.mlr.press/v173/fodor22a/fodor22a.pdf), [website](https://proceedings.mlr.press/v173/fodor22a.html))
|
|
96
|
+
* code: https://github.com/fodorad/PersonalityLinMulT
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Citation - BibTex
|
|
100
|
+
If you found our research helpful or influential please consider citing:
|
|
101
|
+
|
|
102
|
+
### (2023) LinMulT for blink presence detection and eye state recognition:
|
|
103
|
+
```
|
|
104
|
+
@article{blinklinmult-fodor23,
|
|
105
|
+
title = {BlinkLinMulT: Transformer-based Eye Blink Detection},
|
|
106
|
+
author = {Fodor, {\'A}d{\'a}m and Fenech, Kristian and L{\H{o}}rincz, Andr{\'a}s},
|
|
107
|
+
journal = {...},
|
|
108
|
+
pages = {1--19},
|
|
109
|
+
year = {2023}
|
|
110
|
+
}
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### (2022) LinMulT for personality trait and sentiment estimation:
|
|
114
|
+
```
|
|
115
|
+
@InProceedings{pmlr-v173-fodor22a,
|
|
116
|
+
title = {Multimodal Sentiment and Personality Perception Under Speech: A Comparison of Transformer-based Architectures},
|
|
117
|
+
author = {Fodor, {\'A}d{\'a}m and Saboundji, Rachid R. and Jacques Junior, Julio C. S. and Escalera, Sergio and Gallardo-Pujol, David and L{\H{o}}rincz, Andr{\'a}s},
|
|
118
|
+
booktitle = {Understanding Social Behavior in Dyadic and Small Group Interactions},
|
|
119
|
+
pages = {218--241},
|
|
120
|
+
year = {2022},
|
|
121
|
+
editor = {Palmero, Cristina and Jacques Junior, Julio C. S. and Clapés, Albert and Guyon, Isabelle and Tu, Wei-Wei and Moeslund, Thomas B. and Escalera, Sergio},
|
|
122
|
+
volume = {173},
|
|
123
|
+
series = {Proceedings of Machine Learning Research},
|
|
124
|
+
month = {16 Oct},
|
|
125
|
+
publisher = {PMLR},
|
|
126
|
+
pdf = {https://proceedings.mlr.press/v173/fodor22a/fodor22a.pdf},
|
|
127
|
+
url = {https://proceedings.mlr.press/v173/fodor22a.html}
|
|
128
|
+
}
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
# Acknowledgement
|
|
132
|
+
The code is inspired by the following two materials:
|
|
133
|
+
|
|
134
|
+
### Multimodal Transformer:
|
|
135
|
+
* paper: Multimodal Transformer for Unaligned Multimodal Language Sequences ([1906.00295](https://arxiv.org/pdf/1906.00295.pdf))
|
|
136
|
+
* code: https://github.com/yaohungt/Multimodal-Transformer
|
|
137
|
+
|
|
138
|
+
### Linear Attention:
|
|
139
|
+
* paper: Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention ([2006.16236](https://arxiv.org/pdf/2006.16236.pdf))
|
|
140
|
+
* code: https://github.com/idiap/fast-transformers
|
|
141
|
+
|
|
142
|
+
# Contact
|
|
143
|
+
* Ádám Fodor (foauaai@inf.elte.hu)
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
##########################################################
|
|
2
|
+
# #
|
|
3
|
+
# Code is inspired by the following repositories: #
|
|
4
|
+
# https://github.com/yaohungt/Multimodal-Transformer #
|
|
5
|
+
# #
|
|
6
|
+
##########################################################
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Iterable
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from linmult.models.transformer import TransformerEncoder
|
|
13
|
+
|
|
14
|
+
logging.basicConfig(level=logging.INFO,
|
|
15
|
+
format="%(asctime)s %(levelname)s %(message)s",
|
|
16
|
+
datefmt="%Y-%m-%d %H:%M:%S")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LinMulT(nn.Module):
|
|
20
|
+
|
|
21
|
+
def __init__(self,
|
|
22
|
+
input_modality_channels: Iterable[int],
|
|
23
|
+
output_dim: int,
|
|
24
|
+
projected_modality_dim: int | list = 40, # d
|
|
25
|
+
number_of_heads: int = 8,
|
|
26
|
+
number_of_layers: int = 4, # D
|
|
27
|
+
embedding_dropout: float = 0.1,
|
|
28
|
+
cross_attention_dropout: float = 0.1,
|
|
29
|
+
self_attention_dropout: float = 0.0,
|
|
30
|
+
relu_dropout: float = 0.1,
|
|
31
|
+
residual_dropout: float = 0.1,
|
|
32
|
+
output_dropout: float = 0.1,
|
|
33
|
+
attention_mask: bool = True,
|
|
34
|
+
add_time_collapse: bool = False,
|
|
35
|
+
add_self_attention_fusion: bool = True,
|
|
36
|
+
add_projection_fusion: bool = True,
|
|
37
|
+
aggregation: str = 'meanpooling'):
|
|
38
|
+
"""Construct a MulT model with linear attention mechanism.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
aggregation (str | None): aggregation applied to the output sequence to get output_cls.
|
|
42
|
+
None - when add_time_collapse is True, aggregation is not used at all.
|
|
43
|
+
last - last timestep is used. Original MulT implementation.
|
|
44
|
+
cls - classification token is used.
|
|
45
|
+
meanpooling - mean is calculated over the T time dimension.
|
|
46
|
+
maxpooling - max is calculated over the T time dimension.
|
|
47
|
+
"""
|
|
48
|
+
super().__init__()
|
|
49
|
+
|
|
50
|
+
if aggregation not in {None, 'last', 'cls', 'meanpooling', 'maxpooling'}:
|
|
51
|
+
raise Exception(f'Invalid aggregation {aggregation}.')
|
|
52
|
+
|
|
53
|
+
if add_time_collapse and add_self_attention_fusion:
|
|
54
|
+
raise Exception(f'These arguments cannot be True at the same time: {{add_time_collapse, add_self_attention_fusion}}')
|
|
55
|
+
|
|
56
|
+
self.input_modality_channels = input_modality_channels
|
|
57
|
+
self.output_dim = output_dim
|
|
58
|
+
self.number_of_modalities = len(self.input_modality_channels)
|
|
59
|
+
|
|
60
|
+
if isinstance(projected_modality_dim, int):
|
|
61
|
+
self.projected_modality_dim = [projected_modality_dim] * self.number_of_modalities
|
|
62
|
+
else: # list
|
|
63
|
+
if len(projected_modality_dim) != self.number_of_modalities:
|
|
64
|
+
raise Exception('Length of projected_modality_dim should be the number of modalities.')
|
|
65
|
+
self.projected_modality_dim = projected_modality_dim
|
|
66
|
+
|
|
67
|
+
self.number_of_heads = number_of_heads
|
|
68
|
+
self.number_of_layers = number_of_layers
|
|
69
|
+
self.embedding_dropout = embedding_dropout
|
|
70
|
+
self.cross_attention_dropout = cross_attention_dropout
|
|
71
|
+
self.self_attention_dropout = self_attention_dropout
|
|
72
|
+
self.relu_dropout = relu_dropout
|
|
73
|
+
self.residual_dropout = residual_dropout
|
|
74
|
+
self.output_dropout = output_dropout
|
|
75
|
+
self.attention_mask = attention_mask
|
|
76
|
+
self.add_time_collapse = add_time_collapse
|
|
77
|
+
self.add_self_attention_fusion = add_self_attention_fusion
|
|
78
|
+
self.add_projection_fusion = add_projection_fusion
|
|
79
|
+
self.aggregation = aggregation if not add_time_collapse else None
|
|
80
|
+
combined_dim = (self.number_of_modalities - 1) * torch.tensor(self.projected_modality_dim).sum()
|
|
81
|
+
|
|
82
|
+
# 1. Temporal Convolutional Layers
|
|
83
|
+
self.projectors = nn.ModuleList([
|
|
84
|
+
nn.Conv1d(input_modality_channels, projected_modality_dim, kernel_size=1, padding=0, bias=False)
|
|
85
|
+
for input_modality_channels, projected_modality_dim
|
|
86
|
+
in zip(self.input_modality_channels, self.projected_modality_dim)
|
|
87
|
+
])
|
|
88
|
+
|
|
89
|
+
# 2. Crossmodal Attention Transformers
|
|
90
|
+
# e.g.: a, v, t modalities correspond to 0, 1, 2 indices
|
|
91
|
+
# Q -> a, K and V -> v, t: v t - 1 2
|
|
92
|
+
# Q -> v, K and V -> a, t: a t - 0 2
|
|
93
|
+
# Q -> t, K and V -> a, v: a v - 0 1
|
|
94
|
+
self.modality_indices = range(self.number_of_modalities)
|
|
95
|
+
self.crossmodal_transformers = nn.ModuleList([])
|
|
96
|
+
for target_index in self.modality_indices: # e.g. target_index = 0
|
|
97
|
+
input_indices = [ind for ind in self.modality_indices if ind != target_index] # e.g. input_indices = [1, 2]
|
|
98
|
+
self.crossmodal_transformers.append(
|
|
99
|
+
nn.ModuleList([
|
|
100
|
+
self.create_transformer(modality_index=input_index, attention_type='cross')
|
|
101
|
+
for input_index in input_indices
|
|
102
|
+
])
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# 3. Self Attention Transformers
|
|
106
|
+
self.self_attention_transformers = nn.ModuleList([
|
|
107
|
+
self.create_transformer(modality_index=target_index, attention_type='self', layers=3)
|
|
108
|
+
for target_index in self.modality_indices
|
|
109
|
+
])
|
|
110
|
+
|
|
111
|
+
# 4. Self Attention Fusion Transformer
|
|
112
|
+
if self.add_self_attention_fusion:
|
|
113
|
+
self.self_attention_fusion_transformer = self.create_fusion_transformer()
|
|
114
|
+
|
|
115
|
+
if self.add_projection_fusion:
|
|
116
|
+
self.projection_1 = nn.Linear(combined_dim, combined_dim)
|
|
117
|
+
self.projection_2 = nn.Linear(combined_dim, combined_dim)
|
|
118
|
+
|
|
119
|
+
# 5. Sequence Head & Aggregation
|
|
120
|
+
self.out_layer = nn.Linear(combined_dim, self.output_dim) # (B, T, output_dim) or (B, output_dim)
|
|
121
|
+
|
|
122
|
+
def create_transformer(self, modality_index, attention_type: str, layers=-1):
|
|
123
|
+
if attention_type == 'cross': # Crossmodal Attention Transformer
|
|
124
|
+
embedding_dim = self.projected_modality_dim[modality_index]
|
|
125
|
+
attention_dropout = self.cross_attention_dropout
|
|
126
|
+
else: # Self Attention Transformer
|
|
127
|
+
embedding_dim = (self.number_of_modalities - 1) * self.projected_modality_dim[modality_index]
|
|
128
|
+
attention_dropout = self.self_attention_dropout
|
|
129
|
+
|
|
130
|
+
return TransformerEncoder(embedding_dim=embedding_dim,
|
|
131
|
+
number_of_heads=self.number_of_heads,
|
|
132
|
+
number_of_layers=max(self.number_of_layers, layers),
|
|
133
|
+
attention_dropout=attention_dropout,
|
|
134
|
+
relu_dropout=self.relu_dropout,
|
|
135
|
+
residual_dropout=self.residual_dropout,
|
|
136
|
+
embedding_dropout=self.embedding_dropout,
|
|
137
|
+
attention_mask=self.attention_mask)
|
|
138
|
+
|
|
139
|
+
def create_fusion_transformer(self, layers=-1):
|
|
140
|
+
return TransformerEncoder(embedding_dim=self.number_of_modalities * self.projected_modality_dim[0],
|
|
141
|
+
number_of_heads=self.number_of_heads,
|
|
142
|
+
number_of_layers=max(self.number_of_layers, layers),
|
|
143
|
+
attention_dropout=self.self_attention_dropout,
|
|
144
|
+
relu_dropout=self.relu_dropout,
|
|
145
|
+
residual_dropout=self.residual_dropout,
|
|
146
|
+
embedding_dropout=self.self_attention_dropout,
|
|
147
|
+
attention_mask=self.attention_mask)
|
|
148
|
+
|
|
149
|
+
def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
150
|
+
"""Inference with Multimodal Transformer.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
inputs (list[torch.Tensor]): input tensors of shape (B, T, F)
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
(torch.Tensor | tuple[torch.Tensor, torch.Tensor]): tensor of shape (B, F) and/or (B, T, F)
|
|
157
|
+
"""
|
|
158
|
+
# transpose and add embedding dropout
|
|
159
|
+
inp = [] # x_a, x_v, x_t
|
|
160
|
+
for input in inputs:
|
|
161
|
+
input_T = input.transpose(1, 2) # (B, T, F) -> (B, F, T)
|
|
162
|
+
if self.embedding_dropout > 0:
|
|
163
|
+
inp.append(F.dropout(input_T, p=self.embedding_dropout, training=self.training))
|
|
164
|
+
else:
|
|
165
|
+
inp.append(input_T)
|
|
166
|
+
logging.debug(f'input sizes: {[tuple(i.size()) for i in inp]}')
|
|
167
|
+
|
|
168
|
+
# temporal convolution projection of input tensors
|
|
169
|
+
proj_x_mod = [self.projectors[i](input).permute(0, 2, 1) for i, input in enumerate(inp)]
|
|
170
|
+
logging.debug(f'projected input sizes: {[tuple(i.size()) for i in proj_x_mod]}')
|
|
171
|
+
|
|
172
|
+
if self.aggregation == 'cls':
|
|
173
|
+
# add cls token to every input as the first timestamp
|
|
174
|
+
# (projected_dim,) -> (1, 1, projected_dim) -> (batch_size, 1, projected_dim)
|
|
175
|
+
cls_tokens = [
|
|
176
|
+
torch.zeros((proj_x_mod[i].shape[0], 1, proj_x_mod[i].shape[-1]), device=proj_x_mod[i].device)
|
|
177
|
+
for _ in range(self.number_of_modalities)
|
|
178
|
+
]
|
|
179
|
+
|
|
180
|
+
proj_x_mod = [
|
|
181
|
+
torch.cat((cls_token, projected_representation), dim=1)
|
|
182
|
+
for projected_representation, cls_token in zip(proj_x_mod, cls_tokens)
|
|
183
|
+
] # (B, T, F) -> (B, T+1, F)
|
|
184
|
+
|
|
185
|
+
# cross-modal transformers
|
|
186
|
+
hidden_representations = []
|
|
187
|
+
for target_index in range(self.number_of_modalities): # e.g. target_index == 0
|
|
188
|
+
input_indices = [ind for ind in self.modality_indices if ind != target_index] # e.g. input_indices = [1, 2]
|
|
189
|
+
cross_modal_hidden = []
|
|
190
|
+
for i, input_index in enumerate(input_indices):
|
|
191
|
+
# AVT: (V,T) --> A
|
|
192
|
+
logging.debug(f"Query: {[f'modality_{m}' for m in self.modality_indices][target_index]} with shape {tuple(proj_x_mod[target_index].size())} " + \
|
|
193
|
+
f"--> Keys, Values: {[f'modality_{m}' for m in self.modality_indices][input_index]} with shape {tuple(proj_x_mod[input_index].size())}")
|
|
194
|
+
cross_modal_hidden.append(
|
|
195
|
+
self.crossmodal_transformers[target_index][i](
|
|
196
|
+
proj_x_mod[target_index], proj_x_mod[input_index], proj_x_mod[input_index])
|
|
197
|
+
) # Q, K, V
|
|
198
|
+
logging.debug(f"num of crossmodal transformers: {len(cross_modal_hidden)}, tensor shapes: {[tuple(elem.size()) for elem in cross_modal_hidden]}")
|
|
199
|
+
|
|
200
|
+
# self-attention transformer
|
|
201
|
+
cross_modal_hidden = torch.cat(cross_modal_hidden, dim=2) # within branch
|
|
202
|
+
self_hidden = self.self_attention_transformers[target_index](cross_modal_hidden)
|
|
203
|
+
hidden_representations.append(self_hidden) # (B, T, F) or (B, T+1, F)
|
|
204
|
+
logging.debug(f"last hidden representations with shapes: {[tuple(elem.size()) for elem in hidden_representations]}")
|
|
205
|
+
|
|
206
|
+
if self.add_time_collapse:
|
|
207
|
+
hidden_representation = torch.cat([hidden_representation[:,-1,:] for hidden_representation in hidden_representations], dim=-1) # [(B, T, F), ...] -> (B, combined_dim)
|
|
208
|
+
else:
|
|
209
|
+
hidden_representation = torch.cat(hidden_representations, dim=-1) # [(B, T, F), ...] -> (B, T, combined_dim)
|
|
210
|
+
|
|
211
|
+
if self.add_self_attention_fusion:
|
|
212
|
+
hidden_representation = self.self_attention_fusion_transformer(hidden_representation)
|
|
213
|
+
|
|
214
|
+
if self.add_projection_fusion:
|
|
215
|
+
hidden_representation = self.projection_2(F.dropout(F.relu(self.projection_1(hidden_representation)), p=self.output_dropout, training=self.training)) \
|
|
216
|
+
+ hidden_representation # (B, T, combined_dim) or (B, combined_dim)
|
|
217
|
+
|
|
218
|
+
if self.add_time_collapse:
|
|
219
|
+
output_cls = self.out_layer(hidden_representation)
|
|
220
|
+
return output_cls
|
|
221
|
+
else:
|
|
222
|
+
match self.aggregation:
|
|
223
|
+
case 'last':
|
|
224
|
+
output_cls = self.out_layer(hidden_representation[:, -1, :]) # (B, combined_dim)
|
|
225
|
+
case 'cls':
|
|
226
|
+
output_cls = self.out_layer(hidden_representation[:, 0, :]) # (B, T+1, combined_dim) -> (B, combined_dim)
|
|
227
|
+
hidden_representation = hidden_representation[:, 1:, :] # (B, T+1, combined_dim) -> (B, T, combined_dim)
|
|
228
|
+
case 'maxpooling':
|
|
229
|
+
output_cls = self.out_layer(torch.max(hidden_representation, dim=1)) # (B, T, combined_dim) -> (B, combined_dim)
|
|
230
|
+
case _: # 'meanpooling'
|
|
231
|
+
output_cls = self.out_layer(torch.mean(hidden_representation, dim=1)) # (B, T, combined_dim) -> (B, combined_dim)
|
|
232
|
+
|
|
233
|
+
# output_cls head: sequence -> aggregation -> dense -> summarized logits
|
|
234
|
+
# output_seq head: sequence -> time-distributed dense -> sequence-wise logits
|
|
235
|
+
output_seq = self.out_layer(hidden_representation)
|
|
236
|
+
logging.debug(f"output output_cls shape: {tuple(output_cls.size())}")
|
|
237
|
+
logging.debug(f"output output_seq shape: {tuple(output_seq.size())}")
|
|
238
|
+
return output_cls, output_seq
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from linmult.models.transformer import TransformerEncoder
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LinT(nn.Module):
    """Unimodal transformer encoder with linear attention.

    Pipeline: 1x1 temporal convolution projects the input channels to a
    common dimension, a self-attention transformer encodes the sequence,
    and a time-distributed linear head produces per-timestep logits.
    """

    def __init__(self,
                 input_modality_channels: int,
                 output_dim: int,
                 projected_modality_dim: int | list = 40, # d
                 number_of_heads: int = 8,
                 number_of_layers: int = 4, # D
                 embedding_dropout: float = 0.1,
                 cross_attention_dropout: float = 0.1,
                 self_attention_dropout: float = 0.0,
                 relu_dropout: float = 0.1,
                 residual_dropout: float = 0.1,
                 output_dropout: float = 0.1,
                 attention_mask: bool = True):
        """Store the hyperparameters and build the three model stages."""
        super().__init__()

        # Keep every hyperparameter on the instance for introspection.
        self.input_modality_channels = input_modality_channels
        self.output_dim = output_dim
        self.projected_modality_dim = projected_modality_dim
        self.number_of_heads = number_of_heads
        self.number_of_layers = number_of_layers
        self.embedding_dropout = embedding_dropout
        self.cross_attention_dropout = cross_attention_dropout
        self.self_attention_dropout = self_attention_dropout
        self.relu_dropout = relu_dropout
        self.residual_dropout = residual_dropout
        self.output_dropout = output_dropout
        self.attention_mask = attention_mask

        # 1. Temporal convolutional projection: channels F -> d.
        self.projector = nn.Conv1d(input_modality_channels,
                                   projected_modality_dim,
                                   kernel_size=1,
                                   padding=0,
                                   bias=False)

        # 2. Self-attention transformer encoder with linear attention.
        self.self_attention_transformer = TransformerEncoder(
            embedding_dim=self.projected_modality_dim,
            number_of_heads=self.number_of_heads,
            number_of_layers=self.number_of_layers,
            attention_dropout=self.self_attention_dropout,
            relu_dropout=self.relu_dropout,
            residual_dropout=self.residual_dropout,
            embedding_dropout=self.self_attention_dropout,
            attention_mask=self.attention_mask)

        # 3. Time-distributed prediction head: d -> output_dim.
        self.out_layer = nn.Linear(self.projected_modality_dim, self.output_dim)

    def forward(self, input: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
        """input tensor of shape (B, T, F)"""

        # Accept a one-element list for API parity with the multimodal model.
        if isinstance(input, list):
            if len(input) != 1:
                raise Exception(f'A single tensor is expected got instead {len(input)}.')
            input = input[0]

        features = input.transpose(1, 2)  # (B, T, F) -> (B, F, T) for Conv1d

        if self.embedding_dropout > 0:
            features = F.dropout(features, p=self.embedding_dropout, training=self.training)

        projected = self.projector(features).permute(0, 2, 1)  # back to (B, T, d)
        encoded = self.self_attention_transformer(projected)
        return self.out_layer(encoded)  # (B, T, output_dim)
|
|
File without changes
|