secryst 0.1.0
- checksums.yaml +7 -0
- data/README.adoc +103 -0
- data/lib/secryst-trainer.rb +8 -0
- data/lib/secryst.rb +11 -0
- data/lib/secryst/clip_grad_norm.rb +25 -0
- data/lib/secryst/multi_head_attention_forward.rb +288 -0
- data/lib/secryst/multihead_attention.rb +156 -0
- data/lib/secryst/trainer.rb +235 -0
- data/lib/secryst/transformer.rb +382 -0
- data/lib/secryst/translator.rb +51 -0
- data/lib/secryst/version.rb +3 -0
- data/lib/secryst/vocab.rb +88 -0
- metadata +95 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
---
SHA256:
  metadata.gz: 919071d9eb29b220b762212fafe8f7f7ace0f0605cde7c0402dbbe8eac2584c6
  data.tar.gz: de1b20e79c10d14261c71e0043e82a196b38b531da2fcdec0db44bcb992c7239
SHA512:
  metadata.gz: 94bfb8aff36218341cd5baa3bcb0eb48e5be147c54f4a8b072e25a82ae027c07bacc7769fafb7f657cff9f638f9a0a4a21385108f0a45384bdf1a7796c74ea67
  data.tar.gz: a1fee058281400822908c30a2d5fcf27ceb320dfed8d78878986460323d4028e78fa89db1579146870864fd2104238ee038e9f7468710863ed16ab6c8652ce03
data/README.adoc
ADDED
@@ -0,0 +1,103 @@
= Secryst

image:https://github.com/secryst/secryst/workflows/test/badge.svg["Build status", link="https://github.com/secryst/secryst/actions?workflow=test"]

== Purpose

A seq2seq transformer suited for transliteration. Written in Ruby.

Secryst was originally built for the
https://www.interscript.com[Interscript project]
(https://github.com/secryst/secryst[at GitHub]).

The goal is to allow:

* Developers to train models and provide the trained models to users. In order to train models, raw computing power and its bindings (e.g. OpenCL) can be used.

* Users of the library in Ruby who only want to "use" the trained models to run them without requiring any special bindings.

== Status

Currently, Secryst works with the Khmer romanization system cited below.

== Prerequisites

* Ruby 2.7 (*MUST* - 2.6 does not work with the latest torch-rb)

* `libtorch` (1.6.0)
* `fftw`
* `gsl`
* `lapack`
* `openblas`

On Ubuntu:

[source,sh]
----
$ sudo apt-get -y install libfftw3-dev libgsl-dev libopenblas-dev \
    liblapack-dev liblapacke-dev unzip automake make gcc g++ \
    libtorch libtorch-dev
$ wget https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
$ unzip libtorch-cxx11-abi-shared-with-deps-1.6.0.zip

$ gem install bundler -v "~> 2"
$ bundle config build.torch-rb \
    --with-torch-dir=$(pwd)/libtorch

$ bundle install
----

On macOS:

[source,sh]
----
$ brew install libtorch gsl lapack openblas fftw automake gcc

$ gem install bundler -v "~> 2"
$ bundle config build.numo-linalg \
    --with-openblas-dir=/usr/local/opt/openblas \
    --with-lapack-lib=/usr/local/opt/lapack

$ bundle install
----

NOTE: (for macOS)
If you mistakenly installed `numo-linalg` without the above configuration
options, please uninstall it with these steps and configure the bundle as
described above:

[source,sh]
----
$ bundle exec gem uninstall numo-linalg
----

== References

Secryst is built on the transformer model with architecture
based on:

* Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin.
  Attention is all you need. 2017. In:
  _Advances in Neural Information Processing Systems_, pages 6000-6010.

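The core operation of that architecture is scaled dot-product attention,
reproduced here from the cited paper for reference:

[latexmath]
++++
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
++++

where latexmath:[d_k] is the dimension of the key vectors. Secryst's
`multi_head_attention_forward` (shown further below) computes this per attention head.
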
The sample transliteration system implemented is the Khmer system:

* https://viblo.asia/p/nlp-khmer-word-segmentation-YWOZrgNNlQ0
* https://viblo.asia/p/nlp-khmer-romanization-using-seq2seq-m68Z07OQKkG

== Origin of name

Scrying is the practice of peering into a crystal sphere for fortune telling.
The purpose of `seq2seq` is nearly like scrying: looking into a crystal sphere
for some machine-learning magic to happen.

"`Secryst`" comes from the combination of "`seq2seq`" + "`crystal`" + "`scrying`".
data/lib/secryst/clip_grad_norm.rb
ADDED
@@ -0,0 +1,25 @@
# ported from https://pytorch.org/docs/master/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_

module Secryst
  class ClipGradNorm < Torch::NN::F
    def self.clip_grad_norm(parameters, max_norm:, norm_type: 2)
      # Only parameters that actually received gradients can be clipped.
      parameters = parameters.select { |p| p.grad }
      max_norm = max_norm.to_f
      if parameters.length == 0
        return Torch.tensor(0.0)
      end
      device = parameters[0].grad.device
      if norm_type == Float::INFINITY
        # ... TODO
      else
        # Total norm over all gradients (norm of the per-parameter gradient norms).
        total_norm = Numo::Linalg.norm(Numo::NArray.concatenate(parameters.map { |p| Numo::Linalg.norm(p.grad.detach.numo, norm_type) }), norm_type)
      end
      clip_coef = max_norm / (total_norm + 1e-6)
      # Rescale every gradient in place when the total norm exceeds max_norm.
      if clip_coef < 1
        parameters.each { |p| p.grad = p.grad.detach * clip_coef }
      end

      return total_norm
    end
  end
end
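
In a torch-rb training step, a clipping helper like this is applied between the backward pass and the optimizer update. The following is a minimal sketch of that placement; the model, data, and hyperparameters are illustrative placeholders, not part of Secryst.

# Minimal sketch: where gradient-norm clipping fits into a training step.
# The model and data below are illustrative placeholders, not Secryst code.
require "torch"
require "secryst"

model     = Torch::NN::Linear.new(10, 3)
criterion = Torch::NN::CrossEntropyLoss.new
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)

input  = Torch.randn(8, 10)
target = Torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])

optimizer.zero_grad
loss = criterion.call(model.call(input), target)
loss.backward

# Rescale all gradients so their combined 2-norm does not exceed 0.5,
# then take the optimizer step with the clipped gradients.
Secryst::ClipGradNorm.clip_grad_norm(model.parameters, max_norm: 0.5)
optimizer.step
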
data/lib/secryst/multi_head_attention_forward.rb
ADDED
@@ -0,0 +1,288 @@
module Secryst
  class MultiHeadAttentionForward < Torch::NN::F
    # Args:
    #   query, key, value: map a query and a set of key-value pairs to an output.
    #     See "Attention Is All You Need" for more details.
    #   embed_dim_to_check: total dimension of the model.
    #   num_heads: parallel attention heads.
    #   in_proj_weight, in_proj_bias: input projection weight and bias.
    #   bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
    #   add_zero_attn: add a new batch of zeros to the key and
    #     value sequences at dim=1.
    #   dropout_p: probability of an element to be zeroed.
    #   out_proj_weight, out_proj_bias: the output projection weight and bias.
    #   training: apply dropout if ``true``.
    #   key_padding_mask: if provided, specified padding elements in the key will
    #     be ignored by the attention. This is a binary mask. When the value is true,
    #     the corresponding value on the attention layer will be filled with -inf.
    #   need_weights: output attn_output_weights.
    #   attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all
    #     the batches, while a 3D mask allows a different mask to be specified for each entry in the batch.
    #   use_separate_proj_weight: the function accepts the projection weights for query, key,
    #     and value in different forms. If false, in_proj_weight will be used, which is
    #     a combination of q_proj_weight, k_proj_weight, v_proj_weight.
    #   q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
    #   static_k, static_v: static key and value used for attention operators.
    # Shape:
    #   Inputs:
    #   - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
    #     the embedding dimension.
    #   - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
    #     the embedding dimension.
    #   - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
    #     the embedding dimension.
    #   - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
    #     If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
    #     will be unchanged. If a BoolTensor is provided, the positions with the
    #     value of ``true`` will be ignored while the positions with the value of ``false`` will be unchanged.
    #   - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
    #     3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
    #     S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
    #     positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
    #     while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true``
    #     are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor
    #     is provided, it will be added to the attention weight.
    #   - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
    #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
    #   - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
    #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
    #   Outputs:
    #   - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
    #     E is the embedding dimension.
    #   - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
    #     L is the target sequence length, S is the source sequence length.
    def self.multi_head_attention_forward(query,
                                          key,
                                          value,
                                          embed_dim_to_check,
                                          num_heads,
                                          in_proj_weight,
                                          in_proj_bias,
                                          bias_k,
                                          bias_v,
                                          add_zero_attn,
                                          dropout_p,
                                          out_proj_weight,
                                          out_proj_bias,
                                          training: true,
                                          key_padding_mask: nil,
                                          need_weights: true,
                                          attn_mask: nil,
                                          use_separate_proj_weight: false,
                                          q_proj_weight: nil,
                                          k_proj_weight: nil,
                                          v_proj_weight: nil,
                                          static_k: nil,
                                          static_v: nil)
      tgt_len, bsz, embed_dim = query.size()
      raise ArgumentError if embed_dim != embed_dim_to_check
      # allow MHA to have different sizes for the feature dimension
      raise ArgumentError if key.size(0) != value.size(0) or key.size(1) != value.size(1)

      head_dim = embed_dim / num_heads
      raise ArgumentError, "embed_dim must be divisible by num_heads" if head_dim * num_heads != embed_dim
      scaling = head_dim.to_f ** -0.5

      if !use_separate_proj_weight
        if Torch.equal(query, key) && Torch.equal(key, value)
          # self-attention
          q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, -1)

        elsif Torch.equal(key, value)
          # encoder-decoder attention
          # This is inline in_proj function with in_proj_weight and in_proj_bias
          _b = in_proj_bias
          _start = 0
          _end = embed_dim
          _w = in_proj_weight.slice(0, _start, _end) # NOTE: inc-trspl
          if _b
            _b = _b.slice(0, _start, _end)
          end
          q = linear(query, _w, _b)

          if !key
            raise ArgumentError if value
            k = nil
            v = nil
          else
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = nil
            _w = in_proj_weight.slice(0, _start)
            if _b
              _b = _b.slice(0, _start)
            end
            k, v = linear(key, _w, _b).chunk(2, -1)
          end

        else
          # This is inline in_proj function with in_proj_weight and in_proj_bias
          _b = in_proj_bias
          _start = 0
          _end = embed_dim
          _w = in_proj_weight.slice(0, _start, _end)
          if _b
            _b = _b.slice(0, _start, _end)
          end
          q = linear(query, _w, _b)

          # This is inline in_proj function with in_proj_weight and in_proj_bias
          _b = in_proj_bias
          _start = embed_dim
          _end = embed_dim * 2
          _w = in_proj_weight.slice(0, _start, _end)
          if _b
            _b = _b.slice(0, _start, _end)
          end
          k = linear(key, _w, _b)

          # This is inline in_proj function with in_proj_weight and in_proj_bias
          _b = in_proj_bias
          _start = embed_dim * 2
          _end = nil
          _w = in_proj_weight.slice(0, _start)
          if _b
            _b = _b.slice(0, _start)
          end
          v = linear(value, _w, _b)
        end
      else
        q_proj_weight_non_opt = q_proj_weight
        len1, len2 = q_proj_weight_non_opt.size()
        raise ArgumentError if len1 != embed_dim || len2 != query.size(-1)

        k_proj_weight_non_opt = k_proj_weight
        len1, len2 = k_proj_weight_non_opt.size()
        raise ArgumentError if len1 != embed_dim || len2 != key.size(-1)

        v_proj_weight_non_opt = v_proj_weight
        len1, len2 = v_proj_weight_non_opt.size()
        raise ArgumentError if len1 != embed_dim || len2 != value.size(-1)

        if in_proj_bias
          q = linear(query, q_proj_weight_non_opt, in_proj_bias.slice(0, 0, embed_dim))
          k = linear(key, k_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim, embed_dim * 2))
          v = linear(value, v_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim * 2))
        else
          q = linear(query, q_proj_weight_non_opt, in_proj_bias)
          k = linear(key, k_proj_weight_non_opt, in_proj_bias)
          v = linear(value, v_proj_weight_non_opt, in_proj_bias)
        end
      end
      q = q * scaling

      if attn_mask
        raise ArgumentError, 'Only float, byte, and bool types are supported for attn_mask, not %s' % attn_mask.dtype unless attn_mask.dtype == Torch.float32 || attn_mask.dtype == Torch.float64 || attn_mask.dtype == Torch.float16 || attn_mask.dtype == Torch.uint8 || attn_mask.dtype == Torch.bool
        if attn_mask.dtype == Torch.uint8
          puts "Byte tensor for attn_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead."
          attn_mask = attn_mask.to(Torch.bool)
        end

        if attn_mask.dim() == 2
          attn_mask = attn_mask.unsqueeze(0)
          raise ArgumentError, 'The size of the 2D attn_mask is not correct.' if attn_mask.size() != [1, query.size(0), key.size(0)]
        elsif attn_mask.dim() == 3
          raise ArgumentError, 'The size of the 3D attn_mask is not correct.' if attn_mask.size() != [bsz * num_heads, query.size(0), key.size(0)]
        else
          raise ArgumentError, "attn_mask's dimension %s is not supported" % attn_mask.dim()
        end
        # attn_mask's dim is 3 now.
      end

      # convert ByteTensor key_padding_mask to bool
      if key_padding_mask && key_padding_mask.dtype == Torch.uint8
        puts("Byte tensor for key_padding_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(Torch.bool)
      end

      if bias_k && bias_v
        if !static_k && !static_v
          k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
          v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
          attn_mask = pad(attn_mask, [0, 1]) if attn_mask
          key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
        else
          raise ArgumentError, "bias cannot be added to static key." unless !static_k
          raise ArgumentError, "bias cannot be added to static value." unless !static_v
        end
      else
        raise ArgumentError unless !bias_k
        raise ArgumentError unless !bias_v
      end

      q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
      k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if k
      v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v

      if static_k
        raise ArgumentError unless static_k.size(0) == bsz * num_heads
        raise ArgumentError unless static_k.size(2) == head_dim
        k = static_k
      end

      if static_v
        raise ArgumentError unless static_v.size(0) == bsz * num_heads
        raise ArgumentError unless static_v.size(2) == head_dim
        v = static_v
      end

      src_len = k.size(1)

      if key_padding_mask
        raise ArgumentError unless key_padding_mask.size(0) == bsz
        raise ArgumentError unless key_padding_mask.size(1) == src_len
      end

      if add_zero_attn
        src_len += 1
        k_sizes = k.size()
        k_sizes[1] = 1
        k = Torch.cat([k, Torch.zeros(k_sizes, dtype: k.dtype, device: k.device)], 1)
        v_sizes = v.size()
        v_sizes[1] = 1
        v = Torch.cat([v, Torch.zeros(v_sizes, dtype: v.dtype, device: v.device)], 1)
        attn_mask = pad(attn_mask, [0, 1]) if attn_mask
        key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
      end

      attn_output_weights = Torch.bmm(q, k.transpose(1, 2))
      raise ArgumentError unless attn_output_weights.size() == [bsz * num_heads, tgt_len, src_len]

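      # q was pre-scaled by 1 / sqrt(head_dim) above, so attn_output_weights now
      # holds the raw per-head scores q*k^T / sqrt(d_k); after the optional
      # masking below they are normalized with softmax and multiplied by v,
      # i.e. the scaled dot-product attention of "Attention Is All You Need".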
      if attn_mask
        if attn_mask.dtype == Torch.bool
          attn_output_weights.masked_fill!(attn_mask, -1.0/0.0)
        else
          attn_output_weights += attn_mask
        end
      end

      if key_padding_mask
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
          key_padding_mask.unsqueeze(1).unsqueeze(2),
          -1.0/0.0
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
      end

      attn_output_weights = softmax(
        attn_output_weights, dim: -1)
      attn_output_weights = dropout(attn_output_weights, p: dropout_p, training: training)

      attn_output = Torch.bmm(attn_output_weights, v)
      raise ArgumentError unless attn_output.size() == [bsz * num_heads, tgt_len, head_dim]
      attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
      attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

      if need_weights
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(1) / num_heads
      else
        return attn_output, nil
      end
    end
  end
end
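
For orientation, a minimal self-attention call of the function above could look like the sketch below. Every tensor here is a random placeholder sized according to the Shape notes in the comments (L, N, E); none of these names come from the gem itself.

# Minimal sketch of a self-attention call; all tensors are random placeholders
# shaped as documented above: query/key/value are (L, N, E).
require "torch"
require "secryst"

l, n, e   = 5, 2, 16   # target length, batch size, embedding dimension
num_heads = 4

query = Torch.randn(l, n, e)   # key and value equal query => self-attention path

in_proj_weight  = Torch.randn(3 * e, e)  # packed q/k/v projection
in_proj_bias    = Torch.zeros(3 * e)
out_proj_weight = Torch.randn(e, e)
out_proj_bias   = Torch.zeros(e)

attn_output, attn_weights =
  Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
    query, query, query,             # query, key, value
    e,                               # embed_dim_to_check
    num_heads,
    in_proj_weight, in_proj_bias,
    nil, nil,                        # bias_k, bias_v
    false,                           # add_zero_attn
    0.0,                             # dropout_p
    out_proj_weight, out_proj_bias,
    training: false,
    need_weights: true
  )

puts attn_output.size.inspect   # => [5, 2, 16]  (L, N, E)
puts attn_weights.size.inspect  # => [2, 5, 5]   (N, L, S)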