secryst 0.1.0

@@ -0,0 +1,156 @@
+ # ported from https://github.com/pytorch/pytorch/blob/4ae832e1060c72cb89de1d9693629783dbe0c9a6/torch/csrc/api/include/torch/nn/functional/activation.h
+
+ require_relative 'multi_head_attention_forward'
+ module Secryst
+   class MultiheadAttention < Torch::NN::Module
+     # Allows the model to jointly attend to information
+     # from different representation subspaces.
+     # See reference: Attention Is All You Need
+     # .. math::
+     #   \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
+     #   \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+     # Args:
+     #   embed_dim: total dimension of the model.
+     #   num_heads: number of parallel attention heads.
+     #   dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+     #   bias: add bias as module parameter. Default: true.
+     #   add_bias_kv: add bias to the key and value sequences at dim=0.
+     #   add_zero_attn: add a new batch of zeros to the key and
+     #     value sequences at dim=1.
+     #   kdim: total number of features in key. Default: nil.
+     #   vdim: total number of features in value. Default: nil.
+     # Note: if kdim and vdim are nil, they will be set to embed_dim such that
+     # query, key, and value have the same number of features.
+     # Examples::
+     #   >>> multihead_attn = MultiheadAttention.new(embed_dim, num_heads)
+     #   >>> attn_output, attn_output_weights = multihead_attn.call(query, key, value)
+     # bias_k: Optional[Torch::Tensor]
+     # bias_v: Optional[Torch::Tensor]
+
+     def initialize(embed_dim, num_heads, dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false, kdim: nil, vdim: nil)
+       super()
+       @embed_dim = embed_dim
+       @kdim = kdim || embed_dim
+       @vdim = vdim || embed_dim
+       @_qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim
+
+       @num_heads = num_heads
+       @dropout = dropout
+       @head_dim = embed_dim / num_heads
+       raise ArgumentError, "embed_dim must be divisible by num_heads" if @head_dim * num_heads != @embed_dim
+
+       if !@_qkv_same_embed_dim
+         @q_proj_weight = Torch::NN::Parameter.new(Torch.empty(embed_dim, embed_dim))
+         @k_proj_weight = Torch::NN::Parameter.new(Torch.empty(embed_dim, @kdim))
+         @v_proj_weight = Torch::NN::Parameter.new(Torch.empty(embed_dim, @vdim))
+         register_parameter('in_proj_weight', nil)
+       else
+         @in_proj_weight = Torch::NN::Parameter.new(Torch.empty(3 * embed_dim, embed_dim))
+         register_parameter('q_proj_weight', nil)
+         register_parameter('k_proj_weight', nil)
+         register_parameter('v_proj_weight', nil)
+       end
+
+       if bias
+         @in_proj_bias = Torch::NN::Parameter.new(Torch.empty(3 * embed_dim))
+       else
+         register_parameter('in_proj_bias', nil)
+       end
+       @out_proj = Torch::NN::Linear.new(embed_dim, embed_dim)
+
+       if add_bias_kv
+         @bias_k = Torch::NN::Parameter.new(Torch.empty(1, 1, embed_dim))
+         @bias_v = Torch::NN::Parameter.new(Torch.empty(1, 1, embed_dim))
+       else
+         @bias_k = @bias_v = nil
+       end
+
+       @add_zero_attn = add_zero_attn
+
+       _reset_parameters
+     end
+
+     def _reset_parameters
+       if @_qkv_same_embed_dim
+         Torch::NN::Init.xavier_uniform!(@in_proj_weight)
+       else
+         Torch::NN::Init.xavier_uniform!(@q_proj_weight)
+         Torch::NN::Init.xavier_uniform!(@k_proj_weight)
+         Torch::NN::Init.xavier_uniform!(@v_proj_weight)
+       end
+
+       if @in_proj_bias
+         Torch::NN::Init.constant!(@in_proj_bias, 0.0)
+         Torch::NN::Init.constant!(@out_proj.bias, 0.0)
+       end
+
+       if @bias_k
+         Torch::NN::Init.xavier_normal!(@bias_k)
+       end
+
+       if @bias_v
+         Torch::NN::Init.xavier_normal!(@bias_v)
+       end
+     end
+
+     # Args:
+     #   query, key, value: map a query and a set of key-value pairs to an output.
+     #     See "Attention Is All You Need" for more details.
+     #   key_padding_mask: if provided, specified padding elements in the key will
+     #     be ignored by the attention. When given a binary mask and a value is true,
+     #     the corresponding value on the attention layer will be ignored. When given
+     #     a byte mask and a value is non-zero, the corresponding value on the attention
+     #     layer will be ignored.
+     #   need_weights: output attn_output_weights.
+     #   attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
+     #     the batches, while a 3D mask allows a different mask to be specified for each entry of the batch.
+     # Shape:
+     #   - Inputs:
+     #     - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+     #       the embedding dimension.
+     #     - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+     #       the embedding dimension.
+     #     - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+     #       the embedding dimension.
+     #     - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+     #       If a ByteTensor is provided, the non-zero positions will be ignored while the
+     #       zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+     #       value of ``true`` will be ignored while the positions with the value of ``false`` will be unchanged.
+     #     - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+     #       3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+     #       S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+     #       positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+     #       while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true``
+     #       are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor
+     #       is provided, it will be added to the attention weight.
+     #   - Outputs:
+     #     - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+     #       E is the embedding dimension.
+     #     - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+     #       L is the target sequence length, S is the source sequence length.
+     def forward(query, key, value, key_padding_mask: nil,
+                 need_weights: true, attn_mask: nil)
+       if !@_qkv_same_embed_dim
+         return Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
+           query, key, value, @embed_dim, @num_heads,
+           @in_proj_weight, @in_proj_bias,
+           @bias_k, @bias_v, @add_zero_attn,
+           @dropout, @out_proj.weight, @out_proj.bias,
+           training: @training,
+           key_padding_mask: key_padding_mask, need_weights: need_weights,
+           attn_mask: attn_mask, use_separate_proj_weight: true,
+           q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight,
+           v_proj_weight: @v_proj_weight)
+       else
+         return Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
+           query, key, value, @embed_dim, @num_heads,
+           @in_proj_weight, @in_proj_bias,
+           @bias_k, @bias_v, @add_zero_attn,
+           @dropout, @out_proj.weight, @out_proj.bias,
+           training: @training,
+           key_padding_mask: key_padding_mask, need_weights: need_weights,
+           attn_mask: attn_mask)
+       end
+     end
+   end
+ end
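
For orientation, here is a minimal usage sketch of the module above. It is a sketch only: it assumes the torch-rb gem is loaded and that `multi_head_attention_forward` (required at the top of the file) is available, and the sizes are arbitrary.

    embed_dim = 16
    num_heads = 4
    mha = Secryst::MultiheadAttention.new(embed_dim, num_heads, dropout: 0.1)

    # Inputs are sequence-first, as documented above: query (L, N, E), key/value (S, N, E).
    query = Torch.rand(5, 2, embed_dim)
    key   = Torch.rand(7, 2, embed_dim)
    value = Torch.rand(7, 2, embed_dim)

    attn_output, attn_output_weights = mha.call(query, key, value)
    # attn_output: (5, 2, 16); attn_output_weights: (2, 5, 7)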
@@ -0,0 +1,235 @@
+ module Secryst
+   class Trainer
+
+     def initialize(
+       model:,
+       batch_size:,
+       lr:,
+       data_input:,
+       data_target:,
+       hyperparameters:,
+       max_epochs: nil,
+       log_interval: 1,
+       checkpoint_every:,
+       checkpoint_dir:,
+       scheduler_step_size:,
+       gamma:
+     )
+       @data_input = File.readlines(data_input, chomp: true)
+       @data_target = File.readlines(data_target, chomp: true)
+
+       @device = "cpu"
+       @lr = lr
+       @scheduler_step_size = scheduler_step_size
+       @gamma = gamma
+       @batch_size = batch_size
+       @model_name = model
+       @max_epochs = max_epochs
+       @log_interval = log_interval
+       @checkpoint_every = checkpoint_every
+       @checkpoint_dir = checkpoint_dir
+       FileUtils.mkdir_p(@checkpoint_dir)
+       generate_vocabs_and_data
+       save_vocabs
+
+       case model
+       when 'transformer'
+         @model = Secryst::Transformer.new(hyperparameters.merge({
+           input_vocab_size: @input_vocab.length,
+           target_vocab_size: @target_vocab.length,
+         }))
+       else
+         raise ArgumentError, 'Only the transformer model is currently supported'
+       end
+     end
+
+     def train
+       best_model = nil
+       best_val_loss = Float::INFINITY
+
+       return unless @model_name == 'transformer'
+
+       criterion = Torch::NN::CrossEntropyLoss.new(ignore_index: index_of('<pad>')).to(@device)
+       optimizer = Torch::Optim::SGD.new(@model.parameters, lr: @lr)
+       scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: @scheduler_step_size, gamma: @gamma)
+
+       total_loss = 0.0
+       start_time = Time.now
+       ntokens = @target_vocab.length
+       epoch = 0
+
+       loop do
+         epoch_start_time = Time.now
+         @model.train
+         total_loss = 0.0
+         @train_data.each.with_index do |batch, i|
+           inputs, targets, decoder_inputs, src_mask, tgt_mask, memory_mask = batch
+           inputs = Torch.tensor(inputs).t
+           decoder_inputs = Torch.tensor(decoder_inputs).t
+           targets = Torch.tensor(targets).t
+           src_key_padding_mask = inputs.t.eq(1)
+           tgt_key_padding_mask = decoder_inputs.t.eq(1)
+
+           optimizer.zero_grad
+           opts = {
+             # src_mask: src_mask,
+             tgt_mask: tgt_mask,
+             # memory_mask: memory_mask,
+             src_key_padding_mask: src_key_padding_mask,
+             tgt_key_padding_mask: tgt_key_padding_mask,
+             memory_key_padding_mask: src_key_padding_mask,
+           }
+           output = @model.call(inputs, decoder_inputs, **opts)
+           loss = criterion.call(output.transpose(0, 1).reshape(-1, ntokens), targets.t.view(-1))
+           loss.backward
+           ClipGradNorm.clip_grad_norm(@model.parameters, max_norm: 0.5)
+           optimizer.step
+
+           # puts "i[#{i}] loss: #{loss}"
+           total_loss += loss.item
+           if (i + 1) % @log_interval == 0
+             cur_loss = total_loss / @log_interval
+             elapsed = Time.now - start_time
+             puts "| epoch #{epoch} | #{i + 1}/#{@train_data.length} batches | "\
+               "lr #{scheduler.get_lr()[0].round(4)} | ms/batch #{(1000 * elapsed.to_f / @log_interval).round} | "\
+               "loss #{cur_loss.round(5)} | ppl #{Math.exp(cur_loss).round(5)}"
+             total_loss = 0
+             start_time = Time.now
+           end
+         end
+
+         if epoch > 0 && epoch % @checkpoint_every == 0
+           puts ">> Saving checkpoint '#{@checkpoint_dir}/checkpoint-#{epoch}.pth'"
+           Torch.save(@model.state_dict, "#{@checkpoint_dir}/checkpoint-#{epoch}.pth")
+         end
+
+         # Evaluate
+         @model.eval
+         total_loss = 0.0
+         Torch.no_grad do
+           @eval_data.each.with_index do |batch, i|
+             inputs, targets, decoder_inputs, src_mask, tgt_mask, memory_mask = batch
+             inputs = Torch.tensor(inputs).t
+             decoder_inputs = Torch.tensor(decoder_inputs).t
+             targets = Torch.tensor(targets).t
+             src_key_padding_mask = inputs.t.eq(1)
+             tgt_key_padding_mask = decoder_inputs.t.eq(1)
+
+             opts = {
+               # src_mask: src_mask,
+               tgt_mask: tgt_mask,
+               # memory_mask: memory_mask,
+               src_key_padding_mask: src_key_padding_mask,
+               tgt_key_padding_mask: tgt_key_padding_mask,
+               memory_key_padding_mask: src_key_padding_mask,
+             }
+             output = @model.call(inputs, decoder_inputs, **opts)
+             output_flat = output.transpose(0, 1).reshape(-1, ntokens)
+
+             total_loss += criterion.call(output_flat, targets.t.view(-1)).item
+           end
+           total_loss = total_loss / @eval_data.length
+           puts('-' * 89)
+           puts "| end of epoch #{epoch} | time: #{(Time.now - epoch_start_time).round(3)}s | "\
+             " valid loss #{total_loss.round(5)} | valid ppl #{Math.exp(total_loss).round(5)} "
+           puts('-' * 89)
+           if total_loss < best_val_loss
+             best_model = @model
+             best_val_loss = total_loss
+           end
+         end
+         scheduler.step
+
+         epoch += 1
+         break if @max_epochs && @max_epochs < epoch
+       end
+     end
+
+     private
+
+     def generate_vocabs_and_data
+       input_texts = []
+       target_texts = []
+       input_vocab_counter = Hash.new(0)
+       target_vocab_counter = Hash.new(0)
+
+       @data_input.each do |input_text|
+         input_text.strip!
+         input_texts.push(input_text)
+         input_text.each_char do |char|
+           input_vocab_counter[char] += 1
+         end
+       end
+
+       @data_target.each do |target_text|
+         target_text.strip!
+         target_texts.push(target_text)
+         target_text.each_char do |char|
+           target_vocab_counter[char] += 1
+         end
+       end
+
+       @input_vocab = Vocab.new(input_vocab_counter)
+       @target_vocab = Vocab.new(target_vocab_counter)
+
+       # Generate train, eval, and test batches
+       seed = 1
+       zipped_texts = input_texts.zip(target_texts)
+       zipped_texts = zipped_texts.shuffle(random: Random.new(seed))
+
+       # train - 90%, eval - 7%, test - 3%
+       train_texts = zipped_texts[0..(zipped_texts.length * 0.9).to_i]
+       eval_texts = zipped_texts[(zipped_texts.length * 0.9).to_i + 1..(zipped_texts.length * 0.97).to_i]
+       test_texts = zipped_texts[(zipped_texts.length * 0.97).to_i + 1..-1]
+
+       # prepare batches
+       @train_data = batchify(train_texts)
+       @eval_data = batchify(eval_texts)
+       @test_data = batchify(test_texts)
+     end
+
+     def pad(arr, length, no_eos: false, no_sos: false)
+       if !no_eos
+         arr = arr + ["<eos>"]
+       end
+       if !no_sos
+         arr = ["<sos>"] + arr
+       end
+       arr.fill("<pad>", arr.length...length)
+     end
+
+     def index_of(token)
+       @target_vocab.stoi[token]
+     end
+
+     def batchify(data)
+       batches = []
+
+       # one batch per @batch_size slice; avoids producing an empty trailing batch
+       (data.length.to_f / @batch_size).ceil.times do |i|
+         input_data = data[i * @batch_size, @batch_size].transpose[0]
+         decoder_input_data = data[i * @batch_size, @batch_size].transpose[1]
+         target_data = data[i * @batch_size, @batch_size].transpose[1]
+         max_input_seq_length = input_data.max_by(&:length).length + 2
+         max_target_seq_length = target_data.max_by(&:length).length + 1
+         src_mask = Torch.triu(Torch.ones(max_input_seq_length, max_input_seq_length)).eq(0).transpose(0, 1)
+         tgt_mask = Torch.triu(Torch.ones(max_target_seq_length, max_target_seq_length)).eq(0).transpose(0, 1)
+         memory_mask = Torch.triu(Torch.ones(max_input_seq_length, max_target_seq_length)).eq(0).transpose(0, 1)
+         batches << [
+           input_data.map { |line| pad(line.chars, max_input_seq_length).map { |c| @input_vocab[c] } },
+           target_data.map { |line| pad(line.chars, max_target_seq_length, no_sos: true).map { |c| @target_vocab[c] } },
+           decoder_input_data.map { |line| pad(line.chars, max_target_seq_length, no_eos: true).map { |c| @target_vocab[c] } },
+           src_mask,
+           tgt_mask,
+           memory_mask
+         ]
+       end
+
+       batches
+     end
+
+     def save_vocabs
+       File.write("#{@checkpoint_dir}/input_vocab.json", JSON.generate(@input_vocab.freqs))
+       File.write("#{@checkpoint_dir}/target_vocab.json", JSON.generate(@target_vocab.freqs))
+     end
+   end
+ end
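
A hypothetical invocation of the Trainer above, as a sketch only: the data paths are placeholders, the hyperparameter keys are assumed to match the Transformer keyword arguments defined in the next file, and the gem's Vocab and ClipGradNorm helpers are assumed to be loaded.

    trainer = Secryst::Trainer.new(
      model: 'transformer',
      batch_size: 32,
      lr: 0.1,
      data_input: 'data/input.txt',    # placeholder path: one source line per example
      data_target: 'data/target.txt',  # placeholder path: one target line per example
      hyperparameters: {
        d_model: 64, nhead: 8,
        num_encoder_layers: 4, num_decoder_layers: 4,
        dim_feedforward: 256, dropout: 0.1, activation: 'relu'
      },
      max_epochs: 10,
      log_interval: 10,
      checkpoint_every: 2,
      checkpoint_dir: 'checkpoints',
      scheduler_step_size: 1,
      gamma: 0.9
    )
    trainer.train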
@@ -0,0 +1,382 @@
+ # ported from https://github.com/pytorch/pytorch/blob/626e410e1dedcdb9d5a410a8827cc7a8a9fbcce1/torch/nn/modules/transformer.py
+
+ module Secryst
+   class Transformer < Torch::NN::Module
+     # A transformer model. The user is able to modify the attributes as needed. The architecture
+     # is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
+     # Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
+     # Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
+     # Processing Systems, pages 6000-6010. Users can build the BERT (https://arxiv.org/abs/1810.04805)
+     # model with corresponding parameters.
+     # Args:
+     #   d_model: the number of expected features in the encoder/decoder inputs (default=512).
+     #   nhead: the number of heads in the multiheadattention models (default=8).
+     #   num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
+     #   num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
+     #   dim_feedforward: the dimension of the feedforward network model (default=2048).
+     #   dropout: the dropout value (default=0.1).
+     #   activation: the activation function of the encoder/decoder intermediate layer, relu or gelu (default=relu).
+     #   custom_encoder: custom encoder (default=nil).
+     #   custom_decoder: custom decoder (default=nil).
+     #   input_vocab_size: size of vocabulary for the input sequence (number of different possible tokens).
+     #   target_vocab_size: size of vocabulary for the target sequence (number of different possible tokens).
+     # Examples::
+     #   >>> transformer_model = Transformer.new(nhead: 16, num_encoder_layers: 12, input_vocab_size: 72, target_vocab_size: 72)
+     #   >>> src = Torch.randint(0, 72, [10, 32])
+     #   >>> tgt = Torch.randint(0, 72, [20, 32])
+     #   >>> out = transformer_model.call(src, tgt)
+     def initialize(d_model: 512, nhead: 8, num_encoder_layers: 6, num_decoder_layers: 6,
+                    dim_feedforward: 2048, dropout: 0.1, activation: 'relu', custom_encoder: nil, custom_decoder: nil,
+                    input_vocab_size:, target_vocab_size:)
+
+       super()
+
+       if custom_encoder
+         @encoder = custom_encoder
+       else
+         encoder_layers = num_encoder_layers.times.map { TransformerEncoderLayer.new(d_model, nhead, dim_feedforward: dim_feedforward, dropout: dropout, activation: activation) }
+         encoder_norm = Torch::NN::LayerNorm.new(d_model)
+         @encoder = TransformerEncoder.new(encoder_layers, encoder_norm, d_model, input_vocab_size, dropout)
+       end
+
+       if custom_decoder
+         @decoder = custom_decoder
+       else
+         decoder_layers = num_decoder_layers.times.map { TransformerDecoderLayer.new(d_model, nhead, dim_feedforward: dim_feedforward, dropout: dropout, activation: activation) }
+         decoder_norm = Torch::NN::LayerNorm.new(d_model)
+         @decoder = TransformerDecoder.new(decoder_layers, decoder_norm, d_model, target_vocab_size, dropout)
+       end
+
+       @linear = Torch::NN::Linear.new(d_model, target_vocab_size)
+       @softmax = Torch::NN::LogSoftmax.new(dim: -1)
+       _reset_parameters()
+
+       @d_model = d_model
+       @nhead = nhead
+     end
+
+     # Take in and process masked source/target sequences.
+     # Args:
+     #   src: the sequence to the encoder (required).
+     #   tgt: the sequence to the decoder (required).
+     #   src_mask: the additive mask for the src sequence (optional).
+     #   tgt_mask: the additive mask for the tgt sequence (optional).
+     #   memory_mask: the additive mask for the encoder output (optional).
+     #   src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
+     #   tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
+     #   memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
+     # Shape:
+     #   - src: :math:`(S, N, E)`.
+     #   - tgt: :math:`(T, N, E)`.
+     #   - src_mask: :math:`(S, S)`.
+     #   - tgt_mask: :math:`(T, T)`.
+     #   - memory_mask: :math:`(T, S)`.
+     #   - src_key_padding_mask: :math:`(N, S)`.
+     #   - tgt_key_padding_mask: :math:`(N, T)`.
+     #   - memory_key_padding_mask: :math:`(N, S)`.
+     #   Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
+     #   positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+     #   while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true``
+     #   are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor
+     #   is provided, it will be added to the attention weight.
+     #   [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
+     #   the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+     #   positions will be unchanged. If a BoolTensor is provided, the positions with the
+     #   value of ``true`` will be ignored while the positions with the value of ``false`` will be unchanged.
+     #   - output: :math:`(T, N, E)`.
+     #   Note: Due to the multi-head attention architecture in the transformer model,
+     #   the output sequence length of a transformer is the same as the input sequence
+     #   (i.e. target) length of the decoder.
+     #   where S is the source sequence length, T is the target sequence length, N is the
+     #   batch size, E is the feature number.
+     # Examples:
+     #   >>> output = transformer_model.call(src, tgt, src_mask: src_mask, tgt_mask: tgt_mask)
+     def forward(src, tgt, src_mask: nil, tgt_mask: nil,
+                 memory_mask: nil, src_key_padding_mask: nil,
+                 tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+       if src.size(1) != tgt.size(1)
+         raise RuntimeError, "the batch number of src and tgt must be equal"
+       end
+
+       memory = @encoder.call(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
+       output = @decoder.call(tgt, memory, tgt_mask: tgt_mask, memory_mask: memory_mask,
+                              tgt_key_padding_mask: tgt_key_padding_mask,
+                              memory_key_padding_mask: memory_key_padding_mask)
+       output = @linear.call(output)
+       output = @softmax.call(output)
+
+       return output
+     end
+
+     def _reset_parameters
+       parameters.each do |p|
+         Torch::NN::Init.xavier_uniform!(p) if p.dim > 1
+       end
+     end
+   end
+
+
+   class TransformerEncoder < Torch::NN::Module
+     # TransformerEncoder is a stack of N encoder layers.
+     # Args:
+     #   encoder_layers: an array of instances of the TransformerEncoderLayer class (required).
+     #   norm: the layer normalization component (optional).
+     #   d_model: the number of expected features in the encoder/decoder inputs.
+     #   vocab_size: size of vocabulary (number of different possible tokens).
+     # Examples::
+     #   >>> encoder_layers = 6.times.map { TransformerEncoderLayer.new(512, 8) }
+     #   >>> transformer_encoder = TransformerEncoder.new(encoder_layers, nil, 512, 72, 0.1)
+     #   >>> src = Torch.randint(0, 72, [10, 32])
+     #   >>> out = transformer_encoder.call(src)
+     def initialize(encoder_layers, norm=nil, d_model, vocab_size, dropout)
+       super()
+       @d_model = d_model
+       encoder_layers.each.with_index do |l, i|
+         instance_variable_set("@layer#{i}", l)
+       end
+       @layers = encoder_layers.length.times.map { |i| instance_variable_get("@layer#{i}") }
+       @num_layers = encoder_layers.length
+       @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
+       @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
+       @norm = norm
+     end
+
+     # Pass the input through the encoder layers in turn.
+     # Args:
+     #   src: the sequence to the encoder (required).
+     #   mask: the mask for the src sequence (optional).
+     #   src_key_padding_mask: the mask for the src keys per batch (optional).
+     # Shape:
+     #   see the docs in the Transformer class.
+     def forward(src, mask: nil, src_key_padding_mask: nil)
+       output = @embedding.call(src) * Math.sqrt(@d_model)
+       output = @pos_encoder.call(output)
+
+       @layers.each do |mod|
+         output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
+       end
+
+       if @norm
+         output = @norm.call(output)
+       end
+
+       return output
+     end
+   end
+
+
+   class TransformerEncoderLayer < Torch::NN::Module
+     # TransformerEncoderLayer is made up of self-attn and a feedforward network.
+     # This standard encoder layer is based on the paper "Attention Is All You Need".
+     # Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+     # Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+     # Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+     # it in a different way during application.
+     # Args:
+     #   d_model: the number of expected features in the input (required).
+     #   nhead: the number of heads in the multiheadattention models (required).
+     #   dim_feedforward: the dimension of the feedforward network model (default=2048).
+     #   dropout: the dropout value (default=0.1).
+     #   activation: the activation function of the intermediate layer, relu or gelu (default=relu).
+     # Examples::
+     #   >>> encoder_layer = TransformerEncoderLayer.new(512, 8)
+     #   >>> src = Torch.rand(10, 32, 512)
+     #   >>> out = encoder_layer.call(src)
+     def initialize(d_model, nhead, dim_feedforward: 2048, dropout: 0.1, activation: "relu")
+       super()
+       @self_attn = MultiheadAttention.new(d_model, nhead, dropout: dropout)
+       # Implementation of the feedforward model
+       @linear1 = Torch::NN::Linear.new(d_model, dim_feedforward)
+       @dropout = Torch::NN::Dropout.new(p: dropout)
+       @linear2 = Torch::NN::Linear.new(dim_feedforward, d_model)
+
+       @norm1 = Torch::NN::LayerNorm.new(d_model)
+       @norm2 = Torch::NN::LayerNorm.new(d_model)
+       @dropout1 = Torch::NN::Dropout.new(p: dropout)
+       @dropout2 = Torch::NN::Dropout.new(p: dropout)
+
+       @activation = _get_activation_fn(activation)
+     end
+
+     # Pass the input through the encoder layer.
+     # Args:
+     #   src: the sequence to the encoder layer (required).
+     #   src_mask: the mask for the src sequence (optional).
+     #   src_key_padding_mask: the mask for the src keys per batch (optional).
+     # Shape:
+     #   see the docs in the Transformer class.
+     def forward(src, src_mask: nil, src_key_padding_mask: nil)
+       src2 = @self_attn.call(src, src, src, attn_mask: src_mask,
+                              key_padding_mask: src_key_padding_mask)[0]
+       src = src + @dropout1.call(src2)
+       src = @norm1.call(src)
+       src2 = @linear2.call(@dropout.call(@activation.call(@linear1.call(src))))
+       src = src + @dropout2.call(src2)
+       src = @norm2.call(src)
+       return src
+     end
+   end
+
+
+   class TransformerDecoder < Torch::NN::Module
+     # TransformerDecoder is a stack of N decoder layers.
+     # Args:
+     #   decoder_layers: an array of instances of the TransformerDecoderLayer class (required).
+     #   norm: the layer normalization component (optional).
+     #   d_model: the number of expected features in the encoder/decoder inputs.
+     #   vocab_size: size of vocabulary (number of different possible tokens).
+     # Examples::
+     #   >>> decoder_layers = 6.times.map { TransformerDecoderLayer.new(512, 8) }
+     #   >>> transformer_decoder = TransformerDecoder.new(decoder_layers, nil, 512, 72, 0.1)
+     #   >>> memory = Torch.rand(10, 32, 512)
+     #   >>> tgt = Torch.randint(0, 72, [20, 32])
+     #   >>> out = transformer_decoder.call(tgt, memory)
+     def initialize(decoder_layers, norm=nil, d_model, vocab_size, dropout)
+       super()
+       @d_model = d_model
+       decoder_layers.each.with_index do |l, i|
+         instance_variable_set("@layer#{i}", l)
+       end
+       @layers = decoder_layers.length.times.map { |i| instance_variable_get("@layer#{i}") }
+       @num_layers = decoder_layers.length
+       @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
+       @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
+       @norm = norm
+     end
+
+     # Pass the inputs (and mask) through the decoder layers in turn.
+     # Args:
+     #   tgt: the sequence to the decoder (required).
+     #   memory: the sequence from the last layer of the encoder (required).
+     #   tgt_mask: the mask for the tgt sequence (optional).
+     #   memory_mask: the mask for the memory sequence (optional).
+     #   tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+     #   memory_key_padding_mask: the mask for the memory keys per batch (optional).
+     # Shape:
+     #   see the docs in the Transformer class.
+     def forward(tgt, memory, tgt_mask: nil,
+                 memory_mask: nil, tgt_key_padding_mask: nil,
+                 memory_key_padding_mask: nil)
+
+       output = @embedding.call(tgt) * Math.sqrt(@d_model)
+       output = @pos_encoder.call(output)
+
+       @layers.each do |mod|
+         output = mod.call(output, memory, tgt_mask: tgt_mask,
+                           memory_mask: memory_mask,
+                           tgt_key_padding_mask: tgt_key_padding_mask,
+                           memory_key_padding_mask: memory_key_padding_mask)
+       end
+
+       if @norm
+         output = @norm.call(output)
+       end
+
+       return output
+     end
+   end
+
+   class TransformerDecoderLayer < Torch::NN::Module
+     # TransformerDecoderLayer is made up of self-attn, multi-head-attn and a feedforward network.
+     # This standard decoder layer is based on the paper "Attention Is All You Need".
+     # Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+     # Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+     # Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+     # it in a different way during application.
+     # Args:
+     #   d_model: the number of expected features in the input (required).
+     #   nhead: the number of heads in the multiheadattention models (required).
+     #   dim_feedforward: the dimension of the feedforward network model (default=2048).
+     #   dropout: the dropout value (default=0.1).
+     #   activation: the activation function of the intermediate layer, relu or gelu (default=relu).
+     # Examples::
+     #   >>> decoder_layer = TransformerDecoderLayer.new(512, 8)
+     #   >>> memory = Torch.rand(10, 32, 512)
+     #   >>> tgt = Torch.rand(20, 32, 512)
+     #   >>> out = decoder_layer.call(tgt, memory)
+
+     def initialize(d_model, nhead, dim_feedforward: 2048, dropout: 0.1, activation: "relu")
+       super()
+       @self_attn = MultiheadAttention.new(d_model, nhead, dropout: dropout)
+       @multihead_attn = MultiheadAttention.new(d_model, nhead, dropout: dropout)
+       # Implementation of the feedforward model
+       @linear1 = Torch::NN::Linear.new(d_model, dim_feedforward)
+       @dropout = Torch::NN::Dropout.new(p: dropout)
+       @linear2 = Torch::NN::Linear.new(dim_feedforward, d_model)
+
+       @norm1 = Torch::NN::LayerNorm.new(d_model)
+       @norm2 = Torch::NN::LayerNorm.new(d_model)
+       @norm3 = Torch::NN::LayerNorm.new(d_model)
+       @dropout1 = Torch::NN::Dropout.new(p: dropout)
+       @dropout2 = Torch::NN::Dropout.new(p: dropout)
+       @dropout3 = Torch::NN::Dropout.new(p: dropout)
+
+       @activation = _get_activation_fn(activation)
+     end
+
+     # Pass the inputs (and mask) through the decoder layer.
+     # Args:
+     #   tgt: the sequence to the decoder layer (required).
+     #   memory: the sequence from the last layer of the encoder (required).
+     #   tgt_mask: the mask for the tgt sequence (optional).
+     #   memory_mask: the mask for the memory sequence (optional).
+     #   tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+     #   memory_key_padding_mask: the mask for the memory keys per batch (optional).
+     # Shape:
+     #   see the docs in the Transformer class.
+     def forward(tgt, memory, tgt_mask: nil, memory_mask: nil,
+                 tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+
+       tgt2 = @self_attn.call(tgt, tgt, tgt, attn_mask: tgt_mask,
+                              key_padding_mask: tgt_key_padding_mask)[0]
+       tgt = tgt + @dropout1.call(tgt2)
+       tgt = @norm1.call(tgt)
+       tgt2 = @multihead_attn.call(tgt, memory, memory, attn_mask: memory_mask,
+                                   key_padding_mask: memory_key_padding_mask)[0]
+       tgt = tgt + @dropout2.call(tgt2)
+       tgt = @norm2.call(tgt)
+       tgt2 = @linear2.call(@dropout.call(@activation.call(@linear1.call(tgt))))
+       tgt = tgt + @dropout3.call(tgt2)
+       tgt = @norm3.call(tgt)
+       return tgt
+     end
+   end
+
+   class PositionalEncoding < Torch::NN::Module
+     # PositionalEncoding injects some information about the relative or absolute position
+     # of the tokens in the sequence. The positional encodings have the same dimension as
+     # the embeddings so that the two can be summed. Here, we use sine and cosine functions
+     # of different frequencies:
+     #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
+     #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
+     def initialize(d_model, dropout: 0.1, max_len: 5000)
+       super()
+       @dropout = Torch::NN::Dropout.new(p: dropout)
+
+       pe = Torch.zeros(max_len, d_model)
+       position = Torch.arange(0, max_len, dtype: :float).unsqueeze(1)
+       div_term = Torch.exp(Torch.arange(0, d_model, 2).float * (-Math.log(10000.0) / d_model))
+       sin = Torch.sin(position * div_term).t
+       cos = Torch.cos(position * div_term).t
+       pe.t!
+       pe.each.with_index do |row, i|
+         pe[i] = sin[i / 2] if i % 2 == 0
+         pe[i] = cos[(i - 1) / 2] if i % 2 != 0
+       end
+       pe.t!
+       pe = pe.unsqueeze(0).transpose(0, 1)
+       register_buffer('pe', pe)
+     end
+
+     def forward(x)
+       x = x + pe.narrow(0, 0, x.size(0))
+       return @dropout.call(x)
+     end
+   end
+ end
+
+
+ def _get_activation_fn(activation)
+   if activation == "relu"
+     return Torch::NN::F.method(:relu)
+   elsif activation == "gelu"
+     return Torch::NN::F.method(:gelu)
+   end
+
+   raise RuntimeError, "activation should be relu/gelu, not %s" % activation
+ end
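
Tying the files together, a minimal sketch of driving the Transformer directly with integer token tensors and a boolean causal mask, mirroring what Trainer#train does. The vocabulary sizes and shapes are arbitrary, and Torch.randint is assumed to be available in torch-rb with a (low, high, size) signature.

    model = Secryst::Transformer.new(
      d_model: 64, nhead: 8,
      num_encoder_layers: 2, num_decoder_layers: 2,
      dim_feedforward: 128, dropout: 0.1,
      input_vocab_size: 40, target_vocab_size: 40
    )

    src = Torch.randint(0, 40, [12, 2])  # (S, N) source token indices
    tgt = Torch.randint(0, 40, [9, 2])   # (T, N) decoder-input token indices

    # Boolean causal mask: true marks positions the decoder may not attend to.
    tgt_mask = Torch.triu(Torch.ones(9, 9)).eq(0).transpose(0, 1)

    out = model.call(src, tgt, tgt_mask: tgt_mask)  # (T, N, target_vocab_size) log-probabilities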