secryst 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.adoc +103 -0
- data/lib/secryst-trainer.rb +8 -0
- data/lib/secryst.rb +11 -0
- data/lib/secryst/clip_grad_norm.rb +25 -0
- data/lib/secryst/multi_head_attention_forward.rb +288 -0
- data/lib/secryst/multihead_attention.rb +156 -0
- data/lib/secryst/trainer.rb +235 -0
- data/lib/secryst/transformer.rb +382 -0
- data/lib/secryst/translator.rb +51 -0
- data/lib/secryst/version.rb +3 -0
- data/lib/secryst/vocab.rb +88 -0
- metadata +95 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
+---
+SHA256:
+  metadata.gz: 919071d9eb29b220b762212fafe8f7f7ace0f0605cde7c0402dbbe8eac2584c6
+  data.tar.gz: de1b20e79c10d14261c71e0043e82a196b38b531da2fcdec0db44bcb992c7239
+SHA512:
+  metadata.gz: 94bfb8aff36218341cd5baa3bcb0eb48e5be147c54f4a8b072e25a82ae027c07bacc7769fafb7f657cff9f638f9a0a4a21385108f0a45384bdf1a7796c74ea67
+  data.tar.gz: a1fee058281400822908c30a2d5fcf27ceb320dfed8d78878986460323d4028e78fa89db1579146870864fd2104238ee038e9f7468710863ed16ab6c8652ce03
data/README.adoc
ADDED
@@ -0,0 +1,103 @@
+= Secryst
+
+image:https://github.com/secryst/secryst/workflows/test/badge.svg["Build status", link="https://github.com/secryst/secryst/actions?workflow=test"]
+
+== Purpose
+
+A seq2seq transformer suited for transliteration. Written in Ruby.
+
+Secryst was originally built for the
+https://www.interscript.com[Interscript project]
+(https://github.com/secryst/secryst[at GitHub]).
+
+The goal is to allow:
+
+* Developers to train models and provide the trained models to users. To train models, native compute libraries and their bindings can be used, e.g. OpenCL.
+
+* Ruby users of the library who only want to run the trained models to do so without requiring special bindings.
+
+
+== Status
+
+Currently, Secryst works with the Khmer romanization system cited below.
+
+
+== Prerequisites
+
+* Ruby 2.7 (*MUST* - 2.6 does not work with the latest torch-rb)
+
+* `libtorch` (1.6.0)
+* `fftw`
+* `gsl`
+* `lapack`
+* `openblas`
+
+
+On Ubuntu:
+
+[source,sh]
+----
+$ sudo apt-get -y install libfftw3-dev libgsl-dev libopenblas-dev \
+    liblapack-dev liblapacke-dev unzip automake make gcc g++ \
+    libtorch libtorch-dev
+$ wget https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
+$ unzip libtorch-cxx11-abi-shared-with-deps-1.6.0.zip
+
+$ gem install bundler -v "~> 2"
+$ bundle config build.torch-rb \
+    --with-torch-dir=$(pwd)/libtorch
+
+$ bundle install
+----
+
+
+On macOS:
+
+[source,sh]
+----
+$ brew install libtorch gsl lapack openblas fftw automake gcc
+
+$ gem install bundler -v "~> 2"
+$ bundle config build.numo-linalg \
+    --with-openblas-dir=/usr/local/opt/openblas \
+    --with-lapack-lib=/usr/local/opt/lapack
+
+$ bundle install
+----
+
+
+NOTE: (for macOS)
+If you mistakenly installed `numo-linalg` without the above configuration
+options, uninstall it as follows and then configure the bundle as
+described above:
+
+[source,sh]
+----
+$ bundle exec gem uninstall numo-linalg
+----
+
+
+
+== References
+
+Secryst is built on the transformer model, with an architecture
+based on:
+
+* Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+  Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin.
+  Attention is all you need. 2017. In:
+  _Advances in Neural Information Processing Systems_, pages 6000-6010.
+
+The sample transliteration system implemented is the Khmer system:
+
+* https://viblo.asia/p/nlp-khmer-word-segmentation-YWOZrgNNlQ0
+* https://viblo.asia/p/nlp-khmer-romanization-using-seq2seq-m68Z07OQKkG
+
+
+== Origin of name
+
+Scrying is the practice of peering into a crystal sphere for fortune telling.
+The purpose of `seq2seq` is nearly like scrying: looking into a crystal sphere
+for some machine-learning magic to happen.
+
+"`Secryst`" comes from the combination of "`seq2seq`" + "`crystal`" + "`scrying`".
data/lib/secryst/clip_grad_norm.rb
ADDED
@@ -0,0 +1,25 @@
+# ported from https://pytorch.org/docs/master/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
+
+module Secryst
+  class ClipGradNorm < Torch::NN::F
+    def self.clip_grad_norm(parameters, max_norm:, norm_type: 2)
+      parameters = parameters.select { |p| p.grad }
+      max_norm = max_norm.to_f
+      if parameters.length == 0
+        return Torch.tensor(0.0)
+      end
+      device = parameters[0].grad.device
+      if norm_type == Float::INFINITY
+        # ... TODO
+      else
+        total_norm = Numo::Linalg.norm(Numo::NArray.concatenate(parameters.map { |p| Numo::Linalg.norm(p.grad.detach.numo, norm_type) }), norm_type)
+      end
+      clip_coef = max_norm / (total_norm + 1e-6)
+      if clip_coef < 1
+        parameters.each { |p| p.grad = p.grad.detach * clip_coef }
+      end
+
+      return total_norm
+    end
+  end
+end
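A minimal usage sketch (editorial, not part of the published package files) of where `ClipGradNorm` would sit in a training step, assuming torch-rb is loaded; `model`, `loss_fn`, `input`, and `target` are hypothetical placeholders:

[source,ruby]
----
require "torch"

optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)

optimizer.zero_grad
loss = loss_fn.call(model.call(input), target)   # forward pass
loss.backward                                    # populate gradients
# rescale gradients so their overall L2 norm does not exceed 0.5
Secryst::ClipGradNorm.clip_grad_norm(model.parameters, max_norm: 0.5)
optimizer.step                                   # apply the clipped update
----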
data/lib/secryst/multi_head_attention_forward.rb
ADDED
@@ -0,0 +1,288 @@
+module Secryst
+  class MultiHeadAttentionForward < Torch::NN::F
+    # Args:
+    #   query, key, value: map a query and a set of key-value pairs to an output.
+    #     See "Attention Is All You Need" for more details.
+    #   embed_dim_to_check: total dimension of the model.
+    #   num_heads: parallel attention heads.
+    #   in_proj_weight, in_proj_bias: input projection weight and bias.
+    #   bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+    #   add_zero_attn: add a new batch of zeros to the key and
+    #     value sequences at dim=1.
+    #   dropout_p: probability of an element to be zeroed.
+    #   out_proj_weight, out_proj_bias: the output projection weight and bias.
+    #   training: apply dropout if ``true``.
+    #   key_padding_mask: if provided, specified padding elements in the key will
+    #     be ignored by the attention. This is a binary mask. When the value is true,
+    #     the corresponding value on the attention layer will be filled with -inf.
+    #   need_weights: output attn_output_weights.
+    #   attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+    #     the batches while a 3D mask allows specifying a different mask for the entries of each batch.
+    #   use_separate_proj_weight: the function accepts the proj. weights for query, key,
+    #     and value in different forms. If false, in_proj_weight will be used, which is
+    #     a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+    #   q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+    #   static_k, static_v: static key and value used for attention operators.
+    # Shape:
+    #   Inputs:
+    #   - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+    #     the embedding dimension.
+    #   - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+    #     the embedding dimension.
+    #   - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+    #     the embedding dimension.
+    #   - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+    #     If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+    #     will be unchanged. If a BoolTensor is provided, the positions with the
+    #     value of ``true`` will be ignored while the positions with the value of ``false`` will be unchanged.
+    #   - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+    #     3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+    #     S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+    #     positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+    #     while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true``
+    #     are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor
+    #     is provided, it will be added to the attention weight.
+    #   - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+    #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+    #   - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+    #     N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+    #   Outputs:
+    #   - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+    #     E is the embedding dimension.
+    #   - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+    #     L is the target sequence length, S is the source sequence length.
+    def self.multi_head_attention_forward(query,
+                                          key,
+                                          value,
+                                          embed_dim_to_check,
+                                          num_heads,
+                                          in_proj_weight,
+                                          in_proj_bias,
+                                          bias_k,
+                                          bias_v,
+                                          add_zero_attn,
+                                          dropout_p,
+                                          out_proj_weight,
+                                          out_proj_bias,
+                                          training: true,
+                                          key_padding_mask: nil,
+                                          need_weights: true,
+                                          attn_mask: nil,
+                                          use_separate_proj_weight: false,
+                                          q_proj_weight: nil,
+                                          k_proj_weight: nil,
+                                          v_proj_weight: nil,
+                                          static_k: nil,
+                                          static_v: nil)
+      tgt_len, bsz, embed_dim = query.size()
+      raise ArgumentError if embed_dim != embed_dim_to_check
+      # allow MHA to have different sizes for the feature dimension
+      raise ArgumentError if key.size(0) != value.size(0) or key.size(1) != value.size(1)
+
+      head_dim = embed_dim / num_heads
+      raise ArgumentError, "embed_dim must be divisible by num_heads" if head_dim * num_heads != embed_dim
+      scaling = head_dim.to_f ** -0.5
+
+      if !use_separate_proj_weight
+        if Torch.equal(query, key) && Torch.equal(key, value)
+          # self-attention
+          q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, -1)
+
+        elsif Torch.equal(key, value)
+          # encoder-decoder attention
+          # This is inline in_proj function with in_proj_weight and in_proj_bias
+          _b = in_proj_bias
+          _start = 0
+          _end = embed_dim
+          _w = in_proj_weight.slice(0, _start, _end) # NOTE: inc-trspl
+          if _b
+            _b = _b.slice(0, _start, _end)
+          end
+          q = linear(query, _w, _b)
+
+          if !key
+            raise ArgumentError if value
+            k = nil
+            v = nil
+          else
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = nil
+            _w = in_proj_weight.slice(0, _start)
+            if _b
+              _b = _b.slice(0, _start)
+            end
+            k, v = linear(key, _w, _b).chunk(2, -1)
+          end
+
+        else
+          # This is inline in_proj function with in_proj_weight and in_proj_bias
+          _b = in_proj_bias
+          _start = 0
+          _end = embed_dim
+          _w = in_proj_weight.slice(0, _start, _end)
+          if _b
+            _b = _b.slice(0, _start, _end)
+          end
+          q = linear(query, _w, _b)
+
+          # This is inline in_proj function with in_proj_weight and in_proj_bias
+          _b = in_proj_bias
+          _start = embed_dim
+          _end = embed_dim * 2
+          _w = in_proj_weight.slice(0, _start, _end)
+          if _b
+            _b = _b.slice(0, _start, _end)
+          end
+          k = linear(key, _w, _b)
+
+          # This is inline in_proj function with in_proj_weight and in_proj_bias
+          _b = in_proj_bias
+          _start = embed_dim * 2
+          _end = nil
+          _w = in_proj_weight.slice(0, _start)
+          if _b
+            _b = _b.slice(0, _start)
+          end
+          v = linear(value, _w, _b)
+        end
+      else
+        q_proj_weight_non_opt = q_proj_weight
+        len1, len2 = q_proj_weight_non_opt.size()
+        raise ArgumentError if len1 != embed_dim || len2 != query.size(-1)
+
+        k_proj_weight_non_opt = k_proj_weight
+        len1, len2 = k_proj_weight_non_opt.size()
+        raise ArgumentError if len1 != embed_dim || len2 != key.size(-1)
+
+        v_proj_weight_non_opt = v_proj_weight
+        len1, len2 = v_proj_weight_non_opt.size()
+        raise ArgumentError if len1 != embed_dim || len2 != value.size(-1)
+
+        if in_proj_bias
+          q = linear(query, q_proj_weight_non_opt, in_proj_bias.slice(0, 0, embed_dim))
+          k = linear(key, k_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim, embed_dim * 2))
+          v = linear(value, v_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim * 2))
+        else
+          q = linear(query, q_proj_weight_non_opt, in_proj_bias)
+          k = linear(key, k_proj_weight_non_opt, in_proj_bias)
+          v = linear(value, v_proj_weight_non_opt, in_proj_bias)
+        end
+      end
+      q = q * scaling
+
+      if attn_mask
+        raise ArgumentError, 'Only float, byte, and bool types are supported for attn_mask, not %s' % attn_mask.dtype unless attn_mask.dtype == Torch.float32 || attn_mask.dtype == Torch.float64 || attn_mask.dtype == Torch.float16 || attn_mask.dtype == Torch.uint8 || attn_mask.dtype == Torch.bool
+        if attn_mask.dtype == Torch.uint8
+          puts "Byte tensor for attn_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead."
+          attn_mask = attn_mask.to(Torch.bool)
+        end
+
+        if attn_mask.dim() == 2
+          attn_mask = attn_mask.unsqueeze(0)
+          raise ArgumentError, 'The size of the 2D attn_mask is not correct.' if attn_mask.size() != [1, query.size(0), key.size(0)]
+        elsif attn_mask.dim() == 3
+          raise ArgumentError, 'The size of the 3D attn_mask is not correct.' if attn_mask.size() != [bsz * num_heads, query.size(0), key.size(0)]
+        else
+          raise ArgumentError, "attn_mask's dimension %s is not supported" % attn_mask.dim()
+        end
+        # attn_mask's dim is 3 now.
+      end
+
+      # convert ByteTensor key_padding_mask to bool
+      if key_padding_mask && key_padding_mask.dtype == Torch.uint8
+        puts("Byte tensor for key_padding_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead.")
+        key_padding_mask = key_padding_mask.to(Torch.bool)
+      end
+
+      if bias_k && bias_v
+        if !static_k && !static_v
+          k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+          v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+          attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+          key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+        else
+          raise ArgumentError, "bias cannot be added to static key." unless !static_k
+          raise ArgumentError, "bias cannot be added to static value." unless !static_v
+        end
+      else
+        raise ArgumentError unless !bias_k
+        raise ArgumentError unless !bias_v
+      end
+
+      q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+      k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if k
+      v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v
+
+      if static_k
+        raise ArgumentError unless static_k.size(0) == bsz * num_heads
+        raise ArgumentError unless static_k.size(2) == head_dim
+        k = static_k
+      end
+
+      if static_v
+        raise ArgumentError unless static_v.size(0) == bsz * num_heads
+        raise ArgumentError unless static_v.size(2) == head_dim
+        v = static_v
+      end
+
+      src_len = k.size(1)
+
+      if key_padding_mask
+        raise ArgumentError unless key_padding_mask.size(0) == bsz
+        raise ArgumentError unless key_padding_mask.size(1) == src_len
+      end
+
+      if add_zero_attn
+        src_len += 1
+        k_sizes = k.size()
+        k_sizes[1] = 1
+        k = Torch.cat([k, Torch.zeros(k_sizes, dtype: k.dtype, device: k.device)], 1)
+        v_sizes = v.size()
+        v_sizes[1] = 1
+        v = Torch.cat([v, Torch.zeros(v_sizes, dtype: v.dtype, device: v.device)], 1)
+        attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+        key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+      end
+
+      attn_output_weights = Torch.bmm(q, k.transpose(1, 2))
+      raise ArgumentError unless attn_output_weights.size() == [bsz * num_heads, tgt_len, src_len]
+
+      if attn_mask
+        if attn_mask.dtype == Torch.bool
+          attn_output_weights.masked_fill!(attn_mask, -1.0/0.0)
+        else
+          attn_output_weights += attn_mask
+        end
+      end
+
+
+      if key_padding_mask
+        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+        attn_output_weights = attn_output_weights.masked_fill(
+          key_padding_mask.unsqueeze(1).unsqueeze(2),
+          -1.0/0.0
+        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+      end
+
+      attn_output_weights = softmax(
+        attn_output_weights, dim: -1)
+      attn_output_weights = dropout(attn_output_weights, p: dropout_p, training: training)
+
+      attn_output = Torch.bmm(attn_output_weights, v)
+      raise ArgumentError unless attn_output.size() == [bsz * num_heads, tgt_len, head_dim]
+      attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+      attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+      if need_weights
+        # average attention weights over heads
+        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+        return attn_output, attn_output_weights.sum(1) / num_heads
+      else
+        return attn_output, nil
+      end
+    end
+  end
+end
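For orientation, here is a minimal, self-contained sketch (editorial, not part of the package) of the core scaled dot-product attention step that the function above performs per head, following the shapes in its doc comment (`N*num_heads` batched matrices of target length `L` against source length `S`); the sizes below are arbitrary example values:

[source,ruby]
----
require "torch"

# Hypothetical sizes: L = 5 target positions, S = 7 source positions,
# N * num_heads = 8 batched heads, head_dim = 4.
head_dim = 4
q = Torch.randn(8, 5, head_dim) * (head_dim ** -0.5)  # queries, pre-scaled by head_dim**-0.5
k = Torch.randn(8, 7, head_dim)                       # keys
v = Torch.randn(8, 7, head_dim)                       # values

scores  = Torch.bmm(q, k.transpose(1, 2))             # [8, 5, 7] raw attention scores
weights = Torch::NN::F.softmax(scores, dim: -1)       # each row sums to 1
output  = Torch.bmm(weights, v)                       # [8, 5, 4] attended values
----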