secryst 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/README.adoc +103 -0
- data/lib/secryst-trainer.rb +8 -0
- data/lib/secryst.rb +11 -0
- data/lib/secryst/clip_grad_norm.rb +25 -0
- data/lib/secryst/multi_head_attention_forward.rb +288 -0
- data/lib/secryst/multihead_attention.rb +156 -0
- data/lib/secryst/trainer.rb +235 -0
- data/lib/secryst/transformer.rb +382 -0
- data/lib/secryst/translator.rb +51 -0
- data/lib/secryst/version.rb +3 -0
- data/lib/secryst/vocab.rb +88 -0
- metadata +95 -0
@@ -0,0 +1,156 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/4ae832e1060c72cb89de1d9693629783dbe0c9a6/torch/csrc/api/include/torch/nn/functional/activation.h
|
2
|
+
|
3
|
+
require_relative 'multi_head_attention_forward'
|
4
|
+
module Secryst
  # Multi-head attention, ported from PyTorch's nn.MultiheadAttention.
  #
  # Allows the model to jointly attend to information from different
  # representation subspaces. See "Attention Is All You Need":
  #
  #   MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
  #   where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
  #
  # Args:
  #   embed_dim: total dimension of the model.
  #   num_heads: number of parallel attention heads; must be positive and divide embed_dim.
  #   dropout: dropout probability applied to attn_output_weights. Default: 0.0.
  #   bias: add the input-projection bias as a module parameter. Default: true.
  #   add_bias_kv: add bias to the key and value sequences at dim=0. Default: false.
  #   add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default: false.
  #   kdim: total number of features in key. Default: nil (falls back to embed_dim).
  #   vdim: total number of features in value. Default: nil (falls back to embed_dim).
  #
  # When kdim and vdim both equal embed_dim, query, key and value share one packed
  # (3 * embed_dim, embed_dim) projection weight (@in_proj_weight). Otherwise separate
  # @q_proj_weight / @k_proj_weight / @v_proj_weight parameters are created.
  #
  # @bias_k / @bias_v are (1, 1, embed_dim) parameters when add_bias_kv is true, nil otherwise.
  #
  # Example:
  #   multihead_attn = MultiheadAttention.new(embed_dim, num_heads)
  #   attn_output, attn_output_weights = multihead_attn.call(query, key, value)
  class MultiheadAttention < Torch::NN::Module
    # @raise [ArgumentError] if num_heads is not positive or does not divide embed_dim
    def initialize(embed_dim, num_heads, dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false, kdim: nil, vdim: nil)
      super()
      @embed_dim = embed_dim
      @kdim = kdim || embed_dim
      @vdim = vdim || embed_dim
      # A single packed input projection only works when q, k and v share a width.
      @_qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim

      @num_heads = num_heads
      @dropout = dropout
      # Guard explicitly: 0 would otherwise raise ZeroDivisionError below, and a
      # negative count could pass the divisibility check with a negative head_dim.
      raise ArgumentError, "num_heads must be positive" unless num_heads.positive?

      @head_dim = embed_dim / num_heads
      raise ArgumentError, "embed_dim must be divisible by num_heads" if @head_dim * num_heads != @embed_dim

      if @_qkv_same_embed_dim
        @in_proj_weight = Torch::NN::Parameter.new(Torch.empty(3 * embed_dim, embed_dim))
        register_parameter('q_proj_weight', nil)
        register_parameter('k_proj_weight', nil)
        register_parameter('v_proj_weight', nil)
      else
        # Uninitialised storage; _reset_parameters fills these with xavier_uniform!.
        @q_proj_weight = Torch::NN::Parameter.new(Torch.empty(embed_dim, embed_dim))
        @k_proj_weight = Torch::NN::Parameter.new(Torch.empty(embed_dim, @kdim))
        @v_proj_weight = Torch::NN::Parameter.new(Torch.empty(embed_dim, @vdim))
        register_parameter('in_proj_weight', nil)
      end

      if bias
        @in_proj_bias = Torch::NN::Parameter.new(Torch.empty(3 * embed_dim))
      else
        register_parameter('in_proj_bias', nil)
      end
      @out_proj = Torch::NN::Linear.new(embed_dim, embed_dim)

      if add_bias_kv
        @bias_k = Torch::NN::Parameter.new(Torch.empty(1, 1, embed_dim))
        @bias_v = Torch::NN::Parameter.new(Torch.empty(1, 1, embed_dim))
      else
        @bias_k = @bias_v = nil
      end

      @add_zero_attn = add_zero_attn

      _reset_parameters
    end

    # Initialises projection weights with xavier_uniform!, zeroes the input/output
    # projection biases (when bias is enabled) and draws bias_k / bias_v with
    # xavier_normal! when present. Call order is kept fixed so seeded runs reproduce.
    def _reset_parameters
      if @_qkv_same_embed_dim
        Torch::NN::Init.xavier_uniform!(@in_proj_weight)
      else
        [@q_proj_weight, @k_proj_weight, @v_proj_weight].each { |w| Torch::NN::Init.xavier_uniform!(w) }
      end

      if @in_proj_bias
        Torch::NN::Init.constant!(@in_proj_bias, 0.0)
        Torch::NN::Init.constant!(@out_proj.bias, 0.0)
      end

      [@bias_k, @bias_v].compact.each { |b| Torch::NN::Init.xavier_normal!(b) }
    end

    # Args:
    #   query, key, value: map a query and a set of key-value pairs to an output.
    #     See "Attention Is All You Need" for more details.
    #   key_padding_mask: if provided, specified padding elements in the key are
    #     ignored by the attention. With a bool mask, positions that are true are
    #     ignored; with a byte mask, non-zero positions are ignored.
    #   need_weights: output attn_output_weights.
    #   attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D
    #     mask is broadcast over all batches, while a 3D mask specifies a different
    #     mask for each batch entry.
    #
    # Shape:
    #   Inputs:
    #   - query: (L, N, E) where L is the target sequence length, N the batch size,
    #     E the embedding dimension.
    #   - key: (S, N, E) where S is the source sequence length.
    #   - value: (S, N, E).
    #   - key_padding_mask: (N, S). Byte tensor: non-zero positions are ignored.
    #     Bool tensor: true positions are ignored.
    #   - attn_mask: 2D (L, S) or 3D (N * num_heads, L, S). Byte tensor: non-zero
    #     positions may not be attended. Bool tensor: true positions may not be
    #     attended. Float tensor: added to the attention weights.
    #   Outputs:
    #   - attn_output: (L, N, E).
    #   - attn_output_weights: (N, L, S).
    def forward(query, key, value, key_padding_mask: nil, need_weights: true, attn_mask: nil)
      # Only the unpacked layout needs the separate q/k/v projection weights.
      separate_proj =
        if @_qkv_same_embed_dim
          {}
        else
          {
            use_separate_proj_weight: true,
            q_proj_weight: @q_proj_weight,
            k_proj_weight: @k_proj_weight,
            v_proj_weight: @v_proj_weight
          }
        end

      Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
        query, key, value, @embed_dim, @num_heads,
        @in_proj_weight, @in_proj_bias,
        @bias_k, @bias_v, @add_zero_attn,
        @dropout, @out_proj.weight, @out_proj.bias,
        training: @training,
        key_padding_mask: key_padding_mask, need_weights: need_weights,
        attn_mask: attn_mask,
        **separate_proj
      )
    end
  end
end
|
@@ -0,0 +1,235 @@
|
|
1
|
+
module Secryst
  # Trains a character-level sequence-to-sequence model from two line-aligned
  # text files (line i of data_input pairs with line i of data_target).
  #
  # On construction it:
  # - reads both files and builds character vocabularies;
  # - shuffles the pairs with a fixed seed and splits them roughly 90/7/3 into
  #   train/eval/test sets, then batches them;
  # - writes the vocabularies to checkpoint_dir and builds the model.
  class Trainer
    # @param model [String] architecture name; only 'transformer' is supported
    # @param batch_size [Integer] number of sentence pairs per batch
    # @param lr [Float] initial SGD learning rate
    # @param data_input [String] path to the source-side text file (one example per line)
    # @param data_target [String] path to the target-side text file (one example per line)
    # @param hyperparameters [Hash{Symbol=>Object}] extra keyword arguments for Transformer.new
    # @param max_epochs [Integer, nil] index of the last epoch to run. Epochs are 0-based,
    #   so max_epochs + 1 epochs are run; nil trains until interrupted.
    # @param log_interval [Integer] print training stats every this many batches
    # @param checkpoint_every [Integer] save a checkpoint every this many epochs (never after epoch 0)
    # @param checkpoint_dir [String] directory for checkpoints and vocabularies (created if missing)
    # @param scheduler_step_size [Integer] StepLR step size, in epochs (the scheduler steps once per epoch)
    # @param gamma [Float] StepLR multiplicative decay factor
    # @raise [ArgumentError] if model is not 'transformer'
    def initialize(
      model:,
      batch_size:,
      lr:,
      data_input:,
      data_target:,
      hyperparameters:,
      max_epochs: nil,
      log_interval: 1,
      checkpoint_every:,
      checkpoint_dir:,
      scheduler_step_size:,
      gamma:
    )
      @data_input = File.readlines(data_input, chomp: true)
      @data_target = File.readlines(data_target, chomp: true)

      @device = "cpu"
      @lr = lr
      @scheduler_step_size = scheduler_step_size
      @gamma = gamma
      @batch_size = batch_size
      @model_name = model
      @max_epochs = max_epochs
      @log_interval = log_interval
      @checkpoint_every = checkpoint_every
      @checkpoint_dir = checkpoint_dir
      FileUtils.mkdir_p(@checkpoint_dir)
      generate_vocabs_and_data
      save_vocabs

      case model
      when 'transformer'
        # Splat into keywords: Transformer#initialize takes keyword arguments only,
        # and Ruby 3 no longer converts a trailing positional Hash into keywords.
        @model = Secryst::Transformer.new(**hyperparameters.merge(
          input_vocab_size: @input_vocab.length,
          target_vocab_size: @target_vocab.length
        ))
      else
        raise ArgumentError, 'Only transformer model is currently supported'
      end
    end

    # Runs the training loop. Each epoch:
    # 1. trains on every batch of @train_data (SGD, gradient clipping at max norm 0.5);
    # 2. saves a checkpoint every checkpoint_every epochs;
    # 3. reports the mean loss over @eval_data;
    # 4. steps the LR scheduler.
    # The mean eval loss is NaN if the eval split has no batches.
    def train
      return unless @model_name == 'transformer'

      criterion = Torch::NN::CrossEntropyLoss.new(ignore_index: index_of('<pad>')).to(@device)
      optimizer = Torch::Optim::SGD.new(@model.parameters, lr: @lr)
      scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: @scheduler_step_size, gamma: @gamma)

      ntokens = @target_vocab.length
      epoch = 0

      loop do
        epoch_start_time = Time.now
        # Reset per epoch so the previous epoch's eval loss and eval time do not
        # leak into the first training log line.
        total_loss = 0.0
        start_time = Time.now

        @model.train
        @train_data.each_with_index do |batch, i|
          inputs, targets, decoder_inputs, opts = prepare_batch(batch)

          optimizer.zero_grad
          output = @model.call(inputs, decoder_inputs, **opts)
          loss = criterion.call(output.transpose(0, 1).reshape(-1, ntokens), targets.t.view(-1))
          loss.backward
          ClipGradNorm.clip_grad_norm(@model.parameters, max_norm: 0.5)
          optimizer.step

          total_loss += loss.item
          next unless ((i + 1) % @log_interval).zero?

          cur_loss = total_loss / @log_interval
          elapsed = Time.now - start_time
          puts "| epoch #{epoch} | #{i + 1}/#{@train_data.length} batches | "\
            "lr #{scheduler.get_lr()[0].round(4)} | ms/batch #{(1000*elapsed.to_f / @log_interval).round} | "\
            "loss #{cur_loss.round(5)} | ppl #{Math.exp(cur_loss).round(5)}"
          total_loss = 0.0
          start_time = Time.now
        end

        if epoch > 0 && (epoch % @checkpoint_every).zero?
          puts ">> Saving checkpoint '#{@checkpoint_dir}/checkpoint-#{epoch}.pth'"
          Torch.save(@model.state_dict, "#{@checkpoint_dir}/checkpoint-#{epoch}.pth")
        end

        # Evaluate
        @model.eval
        val_loss = 0.0
        Torch.no_grad do
          @eval_data.each do |batch|
            inputs, targets, decoder_inputs, opts = prepare_batch(batch)
            output = @model.call(inputs, decoder_inputs, **opts)
            output_flat = output.transpose(0, 1).reshape(-1, ntokens)
            val_loss += criterion.call(output_flat, targets.t.view(-1)).item
          end
          val_loss /= @eval_data.length
          puts('-' * 89)
          puts "| end of epoch #{epoch} | time: #{(Time.now - epoch_start_time).round(3)}s | "\
            " valid loss #{val_loss.round(5)} | valid ppl #{Math.exp(val_loss).round(5)} "
          puts('-' * 89)
        end
        scheduler.step

        epoch += 1
        break if @max_epochs && @max_epochs < epoch
      end
    end

    private

    # Converts a batch produced by #batchify into (seq, batch) tensors plus the
    # keyword options for Transformer#forward.
    # Returns [inputs, targets, decoder_inputs, opts].
    def prepare_batch(batch)
      inputs, targets, decoder_inputs, _src_mask, tgt_mask, _memory_mask = batch
      inputs = Torch.tensor(inputs).t
      decoder_inputs = Torch.tensor(decoder_inputs).t
      targets = Torch.tensor(targets).t
      # Padding masks are (batch, seq): true where the token is <pad>. The pad index
      # is looked up in each vocab rather than hard-coded, matching the criterion's ignore_index.
      src_key_padding_mask = inputs.t.eq(@input_vocab.stoi['<pad>'])
      tgt_key_padding_mask = decoder_inputs.t.eq(index_of('<pad>'))

      opts = {
        # src_mask and memory_mask are deliberately not passed; only the causal tgt_mask is used.
        tgt_mask: tgt_mask,
        src_key_padding_mask: src_key_padding_mask,
        tgt_key_padding_mask: tgt_key_padding_mask,
        memory_key_padding_mask: src_key_padding_mask
      }
      [inputs, targets, decoder_inputs, opts]
    end

    # Builds @input_vocab / @target_vocab from character counts and fills
    # @train_data, @eval_data and @test_data with batches of shuffled pairs.
    def generate_vocabs_and_data
      input_texts = @data_input.map(&:strip)
      target_texts = @data_target.map(&:strip)

      @input_vocab = Vocab.new(count_chars(input_texts))
      @target_vocab = Vocab.new(count_chars(target_texts))

      # Fixed seed so the train/eval/test split is reproducible between runs.
      zipped_texts = input_texts.zip(target_texts).shuffle(random: Random.new(1))

      # train ~90%, eval ~7%, test ~3% (split points are inclusive indices).
      train_end = (zipped_texts.length * 0.9).to_i
      eval_end = (zipped_texts.length * 0.97).to_i
      # `|| []`: Array#[] returns nil when a range starts past the end (e.g. empty data).
      train_texts = zipped_texts[0..train_end] || []
      eval_texts = zipped_texts[(train_end + 1)..eval_end] || []
      test_texts = zipped_texts[(eval_end + 1)..-1] || []

      @train_data = batchify(train_texts)
      @eval_data = batchify(eval_texts)
      @test_data = batchify(test_texts)
    end

    # Counts character occurrences across texts into a Hash defaulting to 0.
    def count_chars(texts)
      texts.each_with_object(Hash.new(0)) do |text, counter|
        text.each_char { |char| counter[char] += 1 }
      end
    end

    # Wraps a token array in <sos>/<eos> (unless disabled) and right-pads it with
    # <pad> up to length. Returns a new array; longer inputs are not truncated.
    def pad(arr, length, no_eos: false, no_sos: false)
      padded = arr.dup
      padded.push("<eos>") unless no_eos
      padded.unshift("<sos>") unless no_sos
      padded.fill("<pad>", padded.length...length)
    end

    # Index of a token in the target vocabulary.
    def index_of(token)
      @target_vocab.stoi[token]
    end

    # Groups [input_text, target_text] pairs into batches of at most @batch_size.
    # Each batch is:
    #   [input_ids, target_ids, decoder_input_ids, src_mask, tgt_mask, memory_mask]
    # where:
    # - inputs are wrapped in <sos>..<eos>;
    # - targets end with <eos> and decoder inputs start with <sos>;
    # - all sequences are <pad>-padded to the batch maximum.
    # An exact multiple of @batch_size (or empty data) yields no empty trailing batch.
    def batchify(data)
      data.each_slice(@batch_size).map do |slice|
        input_data, target_data = slice.transpose
        decoder_input_data = target_data
        max_input_seq_length = input_data.map(&:length).max + 2 # <sos> + <eos>
        max_target_seq_length = target_data.map(&:length).max + 1 # one of <sos>/<eos>
        src_mask = Torch.triu(Torch.ones(max_input_seq_length, max_input_seq_length)).eq(0).transpose(0, 1)
        tgt_mask = Torch.triu(Torch.ones(max_target_seq_length, max_target_seq_length)).eq(0).transpose(0, 1)
        memory_mask = Torch.triu(Torch.ones(max_input_seq_length, max_target_seq_length)).eq(0).transpose(0, 1)
        [
          input_data.map { |line| pad(line.chars, max_input_seq_length).map { |c| @input_vocab[c] } },
          target_data.map { |line| pad(line.chars, max_target_seq_length, no_sos: true).map { |c| @target_vocab[c] } },
          decoder_input_data.map { |line| pad(line.chars, max_target_seq_length, no_eos: true).map { |c| @target_vocab[c] } },
          src_mask,
          tgt_mask,
          memory_mask
        ]
      end
    end

    # Writes both vocabularies' frequency tables as JSON into the checkpoint dir.
    def save_vocabs
      File.write("#{@checkpoint_dir}/input_vocab.json", JSON.generate(@input_vocab.freqs))
      File.write("#{@checkpoint_dir}/target_vocab.json", JSON.generate(@target_vocab.freqs))
    end
  end
end
|
@@ -0,0 +1,382 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/626e410e1dedcdb9d5a410a8827cc7a8a9fbcce1/torch/nn/modules/transformer.py
|
2
|
+
|
3
|
+
module Secryst
|
4
|
+
class Transformer < Torch::NN::Module
  # Token-level encoder-decoder transformer, ported from PyTorch's nn.Transformer.
  # The architecture is based on "Attention Is All You Need" (Vaswani et al., 2017,
  # NeurIPS, pp. 6000-6010).
  #
  # Unlike nn.Transformer, the default encoder and decoder embed token indices
  # themselves: TransformerEncoder/TransformerDecoder each own an Embedding and a
  # PositionalEncoding. The decoder output is projected to target_vocab_size by a
  # Linear layer and passed through LogSoftmax.
  #
  # Args:
  #   d_model: number of expected features in the encoder/decoder layers (default 512).
  #   nhead: number of heads in the multi-head attention modules (default 8).
  #   num_encoder_layers: number of encoder layers (default 6).
  #   num_decoder_layers: number of decoder layers (default 6).
  #   dim_feedforward: hidden width of the feed-forward networks (default 2048).
  #   dropout: dropout probability (default 0.1).
  #   activation: feed-forward activation name, e.g. "relu" or "gelu" (default "relu").
  #   custom_encoder: module used instead of the default encoder (default nil). It is
  #     called as encoder.call(src, mask:, src_key_padding_mask:).
  #   custom_decoder: module used instead of the default decoder (default nil). It is
  #     called as decoder.call(tgt, memory, tgt_mask:, memory_mask:,
  #     tgt_key_padding_mask:, memory_key_padding_mask:).
  #   input_vocab_size: number of distinct input tokens (required).
  #   target_vocab_size: number of distinct target tokens (required).
  #
  # Example:
  #   transformer_model = Transformer.new(nhead: 16, num_encoder_layers: 12,
  #                                       input_vocab_size: 72, target_vocab_size: 80)
  #   src = Torch.tensor(Array.new(10) { Array.new(32) { rand(72) } })  # (S=10, N=32) token indices
  #   tgt = Torch.tensor(Array.new(20) { Array.new(32) { rand(80) } })  # (T=20, N=32) token indices
  #   out = transformer_model.call(src, tgt)  # (20, 32, 80) log-probabilities
  def initialize(d_model: 512, nhead: 8, num_encoder_layers: 6, num_decoder_layers: 6,
                 dim_feedforward: 2048, dropout: 0.1, activation: 'relu', custom_encoder: nil, custom_decoder: nil,
                 input_vocab_size:, target_vocab_size:)
    super()

    if custom_encoder
      @encoder = custom_encoder
    else
      encoder_layers = num_encoder_layers.times.map { TransformerEncoderLayer.new(d_model, nhead, dim_feedforward: dim_feedforward, dropout: dropout, activation: activation) }
      encoder_norm = Torch::NN::LayerNorm.new(d_model)
      @encoder = TransformerEncoder.new(encoder_layers, encoder_norm, d_model, input_vocab_size, dropout)
    end

    if custom_decoder
      @decoder = custom_decoder
    else
      decoder_layers = num_decoder_layers.times.map { TransformerDecoderLayer.new(d_model, nhead, dim_feedforward: dim_feedforward, dropout: dropout, activation: activation) }
      decoder_norm = Torch::NN::LayerNorm.new(d_model)
      @decoder = TransformerDecoder.new(decoder_layers, decoder_norm, d_model, target_vocab_size, dropout)
    end

    @linear = Torch::NN::Linear.new(d_model, target_vocab_size)
    @softmax = Torch::NN::LogSoftmax.new(dim: -1)
    _reset_parameters

    @d_model = d_model
    @nhead = nhead
  end

  # Takes masked source/target token sequences and returns per-position
  # log-probabilities over the target vocabulary.
  #
  # Args:
  #   src: input token indices for the encoder (required).
  #   tgt: target token indices for the decoder (required).
  #   src_mask: additive/boolean mask for the src sequence (optional).
  #   tgt_mask: additive/boolean mask for the tgt sequence (optional).
  #   memory_mask: additive/boolean mask for the encoder output (optional).
  #   src_key_padding_mask: mask for src keys per batch (optional).
  #   tgt_key_padding_mask: mask for tgt keys per batch (optional).
  #   memory_key_padding_mask: mask for memory keys per batch (optional).
  #
  # Shape (S = source length, T = target length, N = batch size, V = target_vocab_size):
  #   - src: (S, N) token indices.
  #   - tgt: (T, N) token indices.
  #   - src_mask: (S, S).
  #   - tgt_mask: (T, T).
  #   - memory_mask: (T, S).
  #   - src_key_padding_mask: (N, S).
  #   - tgt_key_padding_mask: (N, T).
  #   - memory_key_padding_mask: (N, S).
  #   - output: (T, N, V), log-probabilities (LogSoftmax over the last dim).
  #
  # [src/tgt/memory]_mask ensures that position i may only attend the unmasked
  # positions. Byte tensor: non-zero positions may not be attended, zero positions
  # are unchanged. Bool tensor: true positions may not be attended, false positions
  # are unchanged. Float tensor: added to the attention weights.
  #
  # [src/tgt/memory]_key_padding_mask marks key elements the attention should ignore.
  # Byte tensor: non-zero positions are ignored. Bool tensor: true positions are ignored.
  #
  # @raise [RuntimeError] if src and tgt have different batch sizes (dim 1)
  #
  # Example:
  #   output = transformer_model.call(src, tgt, src_mask: src_mask, tgt_mask: tgt_mask)
  def forward(src, tgt, src_mask: nil, tgt_mask: nil,
              memory_mask: nil, src_key_padding_mask: nil,
              tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
    raise RuntimeError, "the batch number of src and tgt must be equal" if src.size(1) != tgt.size(1)

    memory = @encoder.call(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
    output = @decoder.call(tgt, memory, tgt_mask: tgt_mask, memory_mask: memory_mask,
                           tgt_key_padding_mask: tgt_key_padding_mask,
                           memory_key_padding_mask: memory_key_padding_mask)
    @softmax.call(@linear.call(output))
  end

  # Re-initialises every parameter with more than one dimension (weight matrices,
  # embeddings) using xavier_uniform!; 1-D parameters keep their default init.
  def _reset_parameters
    parameters.each do |p|
      Torch::NN::Init.xavier_uniform!(p) if p.dim > 1
    end
  end
end
|
117
|
+
|
118
|
+
|
119
|
+
class TransformerEncoder < Torch::NN::Module
  # A stack of N encoder layers. Token indices are first embedded, scaled by
  # sqrt(d_model), and positionally encoded. They are then passed through each
  # layer in turn, and the optional norm is applied to the final output.
  #
  # Args:
  #   encoder_layers: array of TransformerEncoderLayer instances (required).
  #   norm: final layer-normalisation module, or nil (optional). Because it sits
  #     before required parameters, calling with four arguments leaves it nil.
  #   d_model: number of expected features in the encoder inputs.
  #   vocab_size: size of the input vocabulary (number of distinct tokens).
  #   dropout: dropout probability passed to PositionalEncoding.
  #
  # Example:
  #   encoder_layers = 6.times.map { TransformerEncoderLayer.new(512, 8) }
  #   transformer_encoder = TransformerEncoder.new(encoder_layers, nil, 512, 72, 0.1)
  #   src = Torch.tensor(Array.new(10) { Array.new(32) { rand(72) } })  # (S=10, N=32) token indices
  #   out = transformer_encoder.call(src)  # (10, 32, 512)
  def initialize(encoder_layers, norm = nil, d_model, vocab_size, dropout)
    super()
    @d_model = d_model
    # Each layer gets its own ivar (@layer0, @layer1, ...), presumably so that
    # Torch::NN::Module's instance-variable scan registers it as a submodule
    # (TODO confirm); @layers keeps the same objects in order for iteration.
    @layers = encoder_layers.each_with_index.map { |layer, i| instance_variable_set("@layer#{i}", layer) }
    @num_layers = encoder_layers.length
    @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
    @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
    @norm = norm
  end

  # Passes the input through the encoder layers in turn.
  # Args:
  #   src: input token indices, (S, N) (required).
  #   mask: mask for the src sequence, forwarded to each layer as src_mask (optional).
  #   src_key_padding_mask: mask for the src keys per batch, (N, S) (optional).
  # Returns the encoded sequence, (S, N, d_model).
  def forward(src, mask: nil, src_key_padding_mask: nil)
    embedded = @pos_encoder.call(@embedding.call(src) * Math.sqrt(@d_model))

    output = @layers.reduce(embedded) do |x, layer|
      layer.call(x, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
    end

    @norm ? @norm.call(output) : output
  end
end
|
166
|
+
|
167
|
+
|
168
|
+
class TransformerEncoderLayer < Torch::NN::Module
  # One encoder block: multi-head self-attention followed by a position-wise
  # feed-forward network. Each sub-layer's output goes through dropout, is added
  # back to its input (residual), and is then layer-normalised (post-norm). The
  # design follows "Attention Is All You Need" (Vaswani et al., 2017, NeurIPS,
  # pp. 6000-6010); users may modify or re-implement it as needed.
  #
  # Args:
  #   d_model: number of expected features in the input (required).
  #   nhead: number of heads in the multi-head attention (required).
  #   dim_feedforward: hidden width of the feed-forward network (default 2048).
  #   dropout: dropout probability (default 0.1).
  #   activation: name of the feed-forward activation, resolved through
  #     _get_activation_fn (default "relu").
  #
  # Example:
  #   encoder_layer = TransformerEncoderLayer.new(512, 8)
  #   src = Torch.rand(10, 32, 512)
  #   out = encoder_layer.call(src)
  def initialize(d_model, nhead, dim_feedforward: 2048, dropout: 0.1, activation: "relu")
    super()
    @self_attn = MultiheadAttention.new(d_model, nhead, dropout: dropout)

    # Feed-forward sub-network: linear1 -> activation -> dropout -> linear2.
    @linear1 = Torch::NN::Linear.new(d_model, dim_feedforward)
    @dropout = Torch::NN::Dropout.new(p: dropout)
    @linear2 = Torch::NN::Linear.new(dim_feedforward, d_model)

    # Per-sub-layer normalisation and residual dropout (creation order kept stable).
    @norm1, @norm2 = Array.new(2) { Torch::NN::LayerNorm.new(d_model) }
    @dropout1, @dropout2 = Array.new(2) { Torch::NN::Dropout.new(p: dropout) }

    @activation = _get_activation_fn(activation)
  end

  # Runs one encoder block over src.
  # Args:
  #   src: input sequence to the layer (required).
  #   src_mask: attention mask for the src sequence (optional).
  #   src_key_padding_mask: padding mask for the src keys per batch (optional).
  # Shapes follow the Transformer class documentation.
  def forward(src, src_mask: nil, src_key_padding_mask: nil)
    # Self-attention sub-layer; element 0 of the attention result is its output.
    attended = @self_attn.call(src, src, src, attn_mask: src_mask,
                               key_padding_mask: src_key_padding_mask)[0]
    hidden = @norm1.call(src + @dropout1.call(attended))

    # Feed-forward sub-layer.
    expanded = @activation.call(@linear1.call(hidden))
    projected = @linear2.call(@dropout.call(expanded))
    @norm2.call(hidden + @dropout2.call(projected))
  end
end
|
219
|
+
|
220
|
+
|
221
|
+
class TransformerDecoder < Torch::NN::Module
  # A stack of N decoder layers. Target token indices are first embedded, scaled
  # by sqrt(d_model), and positionally encoded. They are then passed with the
  # encoder memory through each layer in turn, and the optional norm is applied
  # to the final output.
  #
  # Args:
  #   decoder_layers: array of TransformerDecoderLayer instances (required).
  #   norm: final layer-normalisation module, or nil (optional). Because it sits
  #     before required parameters, calling with four arguments leaves it nil.
  #   d_model: number of expected features in the decoder inputs.
  #   vocab_size: size of the target vocabulary (number of distinct tokens).
  #   dropout: dropout probability passed to PositionalEncoding.
  #
  # Example:
  #   decoder_layers = 6.times.map { TransformerDecoderLayer.new(512, 8) }
  #   transformer_decoder = TransformerDecoder.new(decoder_layers, nil, 512, 72, 0.1)
  #   memory = Torch.rand(10, 32, 512)  # encoder output, (S=10, N=32, d_model)
  #   tgt = Torch.tensor(Array.new(20) { Array.new(32) { rand(72) } })  # (T=20, N=32) token indices
  #   out = transformer_decoder.call(tgt, memory)  # (20, 32, 512)
  def initialize(decoder_layers, norm = nil, d_model, vocab_size, dropout)
    super()
    @d_model = d_model
    # Each layer gets its own ivar (@layer0, @layer1, ...), presumably so that
    # Torch::NN::Module's instance-variable scan registers it as a submodule
    # (TODO confirm); @layers keeps the same objects in order for iteration.
    @layers = decoder_layers.each_with_index.map { |layer, i| instance_variable_set("@layer#{i}", layer) }
    @num_layers = decoder_layers.length
    @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
    @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
    @norm = norm
  end

  # Passes the inputs (and masks) through the decoder layers in turn.
  # Args:
  #   tgt: target token indices, (T, N) (required).
  #   memory: output of the encoder's last layer, (S, N, d_model) (required).
  #   tgt_mask: mask for the tgt sequence (optional).
  #   memory_mask: mask for the memory sequence (optional).
  #   tgt_key_padding_mask: mask for the tgt keys per batch (optional).
  #   memory_key_padding_mask: mask for the memory keys per batch (optional).
  # Returns the decoded sequence, (T, N, d_model).
  def forward(tgt, memory, tgt_mask: nil,
              memory_mask: nil, tgt_key_padding_mask: nil,
              memory_key_padding_mask: nil)
    embedded = @pos_encoder.call(@embedding.call(tgt) * Math.sqrt(@d_model))

    output = @layers.reduce(embedded) do |x, layer|
      layer.call(x, memory, tgt_mask: tgt_mask,
                 memory_mask: memory_mask,
                 tgt_key_padding_mask: tgt_key_padding_mask,
                 memory_key_padding_mask: memory_key_padding_mask)
    end

    @norm ? @norm.call(output) : output
  end
end
|
278
|
+
|
279
|
+
class TransformerDecoderLayer < Torch::NN::Module
  # A single Transformer decoder layer: masked self-attention, attention over the
  # encoder output (memory), and a position-wise feed-forward network. Each
  # sub-layer is followed by dropout, a residual add and a LayerNorm.
  # Based on "Attention Is All You Need": Ashish Vaswani, Noam Shazeer, Niki Parmar,
  # Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin.
  # 2017. In Advances in Neural Information Processing Systems, pages 6000-6010.
  # Users may modify or implement it differently in their own applications.
  # Args:
  #   d_model: the number of expected features in the input (required).
  #   nhead: the number of heads in each multi-head attention module (required).
  #   dim_feedforward: the hidden size of the feed-forward network (default=2048).
  #   dropout: the dropout probability used throughout the layer (default=0.1).
  #   activation: the feed-forward activation, "relu" or "gelu" (default="relu").
  # Examples::
  #   >>> decoder_layer = TransformerDecoderLayer.new(512, 8)
  #   >>> memory = Torch.rand(10, 32, 512)
  #   >>> tgt = Torch.rand(20, 32, 512)
  #   >>> out = decoder_layer.call(tgt, memory)

  def initialize(d_model, nhead, dim_feedforward: 2048, dropout: 0.1, activation: "relu")
    super()
    # Submodules are created in this exact order on purpose: parameter initialization
    # draws random numbers, so reordering would change the initial weights.
    @self_attn = MultiheadAttention.new(d_model, nhead, dropout: dropout)
    @multihead_attn = MultiheadAttention.new(d_model, nhead, dropout: dropout)

    # Position-wise feed-forward network: linear -> activation -> dropout -> linear.
    @linear1 = Torch::NN::Linear.new(d_model, dim_feedforward)
    @dropout = Torch::NN::Dropout.new(p: dropout)
    @linear2 = Torch::NN::Linear.new(dim_feedforward, d_model)

    # One LayerNorm and one residual Dropout per sub-layer.
    @norm1, @norm2, @norm3 = Array.new(3) { Torch::NN::LayerNorm.new(d_model) }
    @dropout1, @dropout2, @dropout3 = Array.new(3) { Torch::NN::Dropout.new(p: dropout) }

    @activation = _get_activation_fn(activation)
  end

  # Runs tgt through the three sub-layers of this decoder layer.
  # Args:
  #   tgt: the sequence to the decoder layer (required).
  #   memory: the sequence from the last layer of the encoder (required).
  #   tgt_mask: the mask for the tgt sequence (optional).
  #   memory_mask: the mask for the memory sequence (optional).
  #   tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  #   memory_key_padding_mask: the mask for the memory keys per batch (optional).
  # Returns the output of the final LayerNorm.
  # Shape:
  #   see the docs in Transformer class.
  def forward(tgt, memory, tgt_mask: nil, memory_mask: nil,
              tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
    # 1) Self-attention; element [0] of the attention result is used as its output.
    attended = @self_attn.call(tgt, tgt, tgt, attn_mask: tgt_mask,
                               key_padding_mask: tgt_key_padding_mask)[0]
    x = @norm1.call(tgt + @dropout1.call(attended))

    # 2) Attention with queries from x and keys/values from the encoder memory.
    attended = @multihead_attn.call(x, memory, memory, attn_mask: memory_mask,
                                    key_padding_mask: memory_key_padding_mask)[0]
    x = @norm2.call(x + @dropout2.call(attended))

    # 3) Feed-forward network.
    hidden = @dropout.call(@activation.call(@linear1.call(x)))
    @norm3.call(x + @dropout3.call(@linear2.call(hidden)))
  end
end
|
344
|
+
|
345
|
+
class PositionalEncoding < Torch::NN::Module
  # Injects information about each token's absolute position in the sequence.
  # The encodings have the same dimension as the embeddings (d_model), so the two
  # can be summed. They use sine and cosine functions of different frequencies:
  #   PE(pos, 2k)   = sin(pos / 10000^(2k / d_model))
  #   PE(pos, 2k+1) = cos(pos / 10000^(2k / d_model))
  # Args:
  #   d_model: embedding dimension.
  #   dropout: dropout probability applied after the encodings are added (default 0.1).
  #   max_len: number of positions precomputed in the table (default 5000).
  def initialize(d_model, dropout: 0.1, max_len: 5000)
    super()
    @dropout = Torch::NN::Dropout.new(p: dropout)

    pe = Torch.zeros(max_len, d_model)
    position = Torch.arange(0, max_len, dtype: :float).unsqueeze(1)
    # div_term[k] = 10000^(-2k / d_model)
    div_term = Torch.exp(Torch.arange(0, d_model, 2).float() * (-Math.log(10000.0) / d_model))
    sin = Torch.sin(position * div_term).t
    cos = Torch.cos(position * div_term).t

    # Work on the transposed table (d_model, max_len) so each row is one feature:
    # even features take sine term i/2, odd features take cosine term i/2
    # (for odd i, integer i/2 == (i-1)/2).
    pe.t!
    d_model.times do |i|
      pe[i] = i.even? ? sin[i / 2] : cos[i / 2]
    end
    pe.t!

    # (max_len, d_model) -> (max_len, 1, d_model): positions along dim 0, and a
    # singleton dim 1 that broadcasts over the input's second dimension.
    pe = pe.unsqueeze(0).transpose(0, 1)
    register_buffer('pe', pe)
  end

  # Adds the encodings for the first x.size(0) positions to x, then applies dropout.
  # Args:
  #   x: input with sequence position along dim 0 and d_model features in the last
  #      dim (presumably (seq_len, batch, d_model) — TODO confirm); x.size(0) must
  #      not exceed max_len.
  # Returns the encoded input after dropout.
  def forward(x)
    x = x + pe.narrow(0, 0, x.size(0))
    # Apply the dropout configured in the constructor. Previously it was built but
    # never used, so the `dropout:` argument had no effect.
    @dropout.call(x)
  end
end
|
371
|
+
end
|
372
|
+
|
373
|
+
|
374
|
+
# Looks up the feed-forward activation function by name.
# Defined at top level (outside module Secryst), so it is callable from any
# object, including TransformerDecoderLayer#initialize.
# Args:
#   activation: "relu" or "gelu".
# Returns a Method object wrapping Torch::NN::F.relu or Torch::NN::F.gelu.
# Raises RuntimeError for any other name.
def _get_activation_fn(activation)
  case activation
  when "relu" then Torch::NN::F.method(:relu)
  when "gelu" then Torch::NN::F.method(:gelu)
  else
    raise RuntimeError, "activation should be relu/gelu, not %s" % activation
  end
end
|