wolof-translate 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. wolof_translate/__init__.py +73 -0
  2. wolof_translate/data/__init__.py +0 -0
  3. wolof_translate/data/dataset_v1.py +151 -0
  4. wolof_translate/data/dataset_v2.py +187 -0
  5. wolof_translate/data/dataset_v3.py +187 -0
  6. wolof_translate/data/dataset_v3_2.py +187 -0
  7. wolof_translate/data/dataset_v4.py +202 -0
  8. wolof_translate/data/dataset_v5.py +65 -0
  9. wolof_translate/models/__init__.py +0 -0
  10. wolof_translate/models/transformers/__init__.py +0 -0
  11. wolof_translate/models/transformers/main.py +865 -0
  12. wolof_translate/models/transformers/main_2.py +362 -0
  13. wolof_translate/models/transformers/optimization.py +41 -0
  14. wolof_translate/models/transformers/position.py +46 -0
  15. wolof_translate/models/transformers/size.py +44 -0
  16. wolof_translate/pipe/__init__.py +1 -0
  17. wolof_translate/pipe/nlp_pipeline.py +512 -0
  18. wolof_translate/tokenizers/__init__.py +0 -0
  19. wolof_translate/trainers/__init__.py +0 -0
  20. wolof_translate/trainers/transformer_trainer.py +760 -0
  21. wolof_translate/trainers/transformer_trainer_custom.py +882 -0
  22. wolof_translate/trainers/transformer_trainer_ml.py +925 -0
  23. wolof_translate/trainers/transformer_trainer_ml_.py +1042 -0
  24. wolof_translate/utils/__init__.py +1 -0
  25. wolof_translate/utils/bucket_iterator.py +143 -0
  26. wolof_translate/utils/database_manager.py +116 -0
  27. wolof_translate/utils/display_predictions.py +162 -0
  28. wolof_translate/utils/download_model.py +40 -0
  29. wolof_translate/utils/evaluate_custom.py +147 -0
  30. wolof_translate/utils/evaluation.py +74 -0
  31. wolof_translate/utils/extract_new_sentences.py +810 -0
  32. wolof_translate/utils/extract_poems.py +60 -0
  33. wolof_translate/utils/extract_sentences.py +562 -0
  34. wolof_translate/utils/improvements/__init__.py +0 -0
  35. wolof_translate/utils/improvements/end_marks.py +45 -0
  36. wolof_translate/utils/recuperate_datasets.py +94 -0
  37. wolof_translate/utils/recuperate_datasets_trunc.py +85 -0
  38. wolof_translate/utils/send_model.py +26 -0
  39. wolof_translate/utils/sent_corrections.py +169 -0
  40. wolof_translate/utils/sent_transformers.py +27 -0
  41. wolof_translate/utils/sent_unification.py +97 -0
  42. wolof_translate/utils/split_with_valid.py +72 -0
  43. wolof_translate/utils/tokenize_text.py +46 -0
  44. wolof_translate/utils/training.py +213 -0
  45. wolof_translate/utils/trunc_hg_training.py +196 -0
  46. wolof_translate-0.0.1.dist-info/METADATA +31 -0
  47. wolof_translate-0.0.1.dist-info/RECORD +49 -0
  48. wolof_translate-0.0.1.dist-info/WHEEL +5 -0
  49. wolof_translate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,865 @@
+ # %%writefile wolof-translate/wolof_translate/models/transformers/main.py
+ from wolof_translate.models.transformers.position import PositionalEncoding
+
+ # from wolof_translate.models.transformers.size import SizePredict
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.nn import functional as F
+ from torch import nn
+ from typing import *
+ import torch
+ import copy
+
+ # new Exception for that transformer
+ class TargetException(Exception):
+     def __init__(self, error):
+
+         print(error)
+
+
+ class GenerationException(Exception):
+     def __init__(self, error):
+
+         print(error)
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         vocab_size: int,
+         encoder,
+         decoder,
+         class_criterion=nn.CrossEntropyLoss(label_smoothing=0.1),
+         # size_criterion = nn.MSELoss(),
+         # n_features: int = 100,
+         # n_layers: int = 2,
+         n_poses_max: int = 2000,
+         projection_type: str = "embedding",
+         max_len: int = 20,
+         share_weight: bool = False,
+     ):
+
+         super(Transformer, self).__init__()
+
+         assert len(encoder.layers) > 0 and len(decoder.layers) > 0
+
+         self.dropout = encoder.layers._modules["0"].dropout.p
+
+         self.enc_embed_dim = encoder.layers._modules["0"].linear1.in_features
+
+         self.dec_embed_dim = decoder.layers._modules["0"].linear1.in_features
+
+         # we can initialize the positional encoding model
+         self.pe = PositionalEncoding(n_poses_max, self.enc_embed_dim)
+
+         if projection_type == "embedding":
+
+             self.embedding_layer = nn.Embedding(vocab_size, self.enc_embed_dim)
+
+         elif projection_type == "linear":
+
+             self.embedding_layer = nn.Linear(vocab_size, self.enc_embed_dim)
+
+         # initialize the encoder and decoder
+         self.encoder = encoder
+
+         self.decoder = decoder
+
+         self.class_criterion = class_criterion
+
+         # add dropout to the inputs and outputs of the encoder and decoder
+         self.in_dropout = nn.Dropout(p=self.dropout)
+
+         self.out_dropout = nn.Dropout(p=self.dropout)
+
+         # self.size_criterion = size_criterion
+
+         # let's initiate the mlp for predicting the target size
+         # self.size_prediction = SizePredict(
+         #     self.enc_embed_dim,
+         #     n_features=n_features,
+         #     n_layers=n_layers,
+         #     normalization=True, # we always use normalization
+         #     drop_out=self.dropout
+         # )
+
+         self.classifier = nn.Linear(self.dec_embed_dim, vocab_size)
+
+         # let us share the weights between the embedding layer and the classification
+         # linear layer
+         if share_weight:
+
+             self.embedding_layer.register_forward_hook(self._copy_embedding_weights)
+
+         self.max_len = max_len
+
+     def forward(
+         self,
+         input_,
+         input_mask=None,
+         target=None,
+         target_mask=None,
+         pad_token_id: int = 3,
+     ):
+
+         # ---> Encoder prediction
+         input_embed = self.embedding_layer(input_)
+
+         # recuperate the last input (before adding positions)
+         last_input = input_embed[:, -1:]
+
+         # add position to input_embedding
+         input_embed = self.pe(input_embed)
+
+         # recuperate the input mask for pytorch encoder
+         pad_mask1 = (
+             (input_mask == 0).to(next(self.parameters()).device, dtype=torch.bool)
+             if not input_mask is None
+             else None
+         )
+
+         # let us compute the states
+         input_embed = input_embed.type_as(next(self.encoder.parameters()))
+
+         input_embed = self.in_dropout(input_embed)  # apply dropout to the input embed
+
+         states = self.encoder(input_embed, src_key_padding_mask=pad_mask1)
+
+         # apply dropout to the states
+         states = self.out_dropout(states)
+
+         # ---> Decoder prediction
+         # let's predict the size of the target
+         # target_size = self.size_prediction(states).mean(axis = 1)
+
+         target_embed = self.embedding_layer(target)
+
+         # recuperate target mask for pytorch decoder
+         pad_mask2 = (
+             (target_mask == 0).to(next(self.parameters()).device, dtype=torch.bool)
+             if not target_mask is None
+             else None
+         )
+
+         # define the attention mask
+         targ_mask = self.get_target_mask(target_embed.size(1))
+
+         # let's concatenate the last input and the target shifted one position to the right (new seq length = target seq length)
+         target_embed = torch.cat((last_input, target_embed[:, :-1]), dim=1)
+
+         # add position to target embed
+         target_embed = self.pe(target_embed)
+
+         # we pass the whole shifted target sequence to the decoder in training mode
+         if self.training:
+
+             target_embed = target_embed.type_as(next(self.encoder.parameters()))
+
+             # add dropout to the target
+             target_embed = self.in_dropout(target_embed)
+
+             outputs = self.decoder(
+                 target_embed, states, tgt_mask=targ_mask, tgt_key_padding_mask=pad_mask2
+             )
+
+             # add dropout to the outputs
+             outputs = self.out_dropout(outputs)
+
+         else:  ## This part was understood with the help of Professor Bousso.
+
+             # in evaluation mode we do not use the target but the outputs to make predictions,
+             # and this is done sequentially (see comments below)
+
+             # let us recuperate the last input as the current outputs
+             outputs = last_input.type_as(next(self.encoder.parameters()))
+
+             # for each target that we want to predict
+             for t in range(target.size(1)):
+
+                 # recuperate the target mask of the current decoder input
+                 current_targ_mask = targ_mask[
+                     : t + 1, : t + 1
+                 ]  # all attentions between the elements before the last target
+
+                 # we do the same for the padding mask
+                 current_pad_mask = None
+
+                 if not pad_mask2 is None:
+
+                     current_pad_mask = pad_mask2[:, : t + 1]
+
+                 # make new predictions
+                 out = self.decoder(
+                     outputs,
+                     states,
+                     tgt_mask=current_targ_mask,
+                     tgt_key_padding_mask=current_pad_mask,
+                 )
+
+                 # add the last new prediction to the decoder inputs
+                 outputs = torch.cat(
+                     (outputs, out[:, -1:]), dim=1
+                 )  # the prediction of the last output is the last to add (!)
+
+             # let's take only the predictions (the last input will not be taken)
+             outputs = outputs[:, 1:]
+
+         # let us mask the padded positions of the target with the ignore index (-100)
+         if not target_mask is None:
+             target = copy.deepcopy(target.cpu())
+             target = target.to(target_mask.device).masked_fill_(target_mask == 0, -100)
+
+         # ---> Loss Calculation
+         # let us calculate the loss of the size prediction
+         # size_loss = 0
+         # if not self.size_criterion is None:
+
+         #     size_loss = self.size_criterion(target_size, target_mask.sum(axis = -1).unsqueeze(1).type_as(next(self.parameters())))
+
+         outputs = self.classifier(outputs)
+
+         # let us permute the two last dimensions of the outputs
+         outputs_ = outputs.permute(0, -1, -2)
+
+         # calculate the loss
+         loss = self.class_criterion(outputs_, target)
+
+         outputs = torch.softmax(outputs, dim=-1)
+
+         # calculate the predictions
+         outputs = copy.deepcopy(outputs.detach().cpu())
+         predictions = (
+             torch.argmax(outputs, dim=-1)
+             .to(target_mask.device)
+             .masked_fill_(target_mask == 0, pad_token_id)
+         )
+
+         return {"loss": loss, "preds": predictions}
+
+     def generate(
+         self,
+         input_,
+         input_mask=None,
+         temperature: float = 0,
+         max_len: Union[int, None] = None,
+     ):
+
+         if self.training:
+
+             raise GenerationException(
+                 "You cannot generate when the model is on training mode!"
+             )
+
+         # recuperate the max len
+         max_len = max_len if not max_len is None else self.max_len
+
+         # ---> Encoder prediction
+         input_embed = self.embedding_layer(input_)
+
+         # recuperate the last input (before adding positions)
+         last_input = input_embed[:, -1:]
+
+         # add position to input_embedding
+         input_embed = self.pe(input_embed)
+
+         # recuperate the input mask for pytorch encoder
+         pad_mask1 = (
+             (input_mask == False).to(next(self.parameters()).device)
+             if not input_mask is None
+             else None
+         )
+
+         # let us compute the states
+         input_embed = input_embed.type_as(next(self.encoder.parameters()))
+
+         states = self.encoder(input_embed, src_key_padding_mask=pad_mask1)
+
+         # ---> Decoder prediction
+         # let us recuperate the maximum length
+         # max_len = self.max_len if not self.max_len is None else 0
+
+         # let's predict the size of the target and the target mask
+         # if max_len > 0:
+
+         #     target_size = self.size_prediction(states).mean(axis = 1).round().clip(1, max_len)
+
+         # else:
+
+         #     target_size = torch.max(self.size_prediction(states).mean(axis = 1).round(), torch.tensor(1.0))
+
+         # target_ = copy.deepcopy(target_size.cpu())
+
+         # target_mask = [torch.tensor(int(size[0])*[1] + [0] * max(max_len - int(size[0]), 0)) for size in target_.tolist()]
+
+         # if max_len > 0:
+
+         #     target_mask = torch.stack(target_mask).to(next(self.parameters()).device, dtype = torch.bool)
+
+         # else:
+
+         #     target_mask = pad_sequence(target_, batch_first = True).to(next(self.parameters()).device, dtype = torch.bool)
+
+         # recuperate target mask for pytorch decoder
+         # pad_mask2 = (target_mask == 0).to(next(self.parameters()).device, dtype = torch.bool) if not target_mask is None else None
+
+         # define the attention mask
+         targ_mask = self.get_target_mask(max_len)
+
+         # in evaluation mode we do not use the target but the outputs to make predictions,
+         # and this is done sequentially (see comments below)
+
+         # let us recuperate the last input as the current outputs
+         outputs = last_input.type_as(next(self.encoder.parameters()))
+
+         # for each target that we want to predict
+         for t in range(max_len):
+
+             # recuperate the target mask of the current decoder input
+             current_targ_mask = targ_mask[
+                 : t + 1, : t + 1
+             ]  # all attentions between the elements before the last target
+
+             # we do the same for the padding mask
+             current_pad_mask = None
+
+             # if not pad_mask2 is None:
+
+             #     current_pad_mask = pad_mask2[:, :t+1]
+
+             # make new predictions
+             out = self.decoder(
+                 outputs,
+                 states,
+                 tgt_mask=current_targ_mask,
+                 tgt_key_padding_mask=current_pad_mask,
+             )
+
+             # add the last new prediction to the decoder inputs
+             outputs = torch.cat(
+                 (outputs, out[:, -1:]), dim=1
+             )  # the prediction of the last output is the last to add (!)
+
+         # let's take only the predictions (the last input will not be taken)
+         outputs = outputs[:, 1:]
+
+         # ---> Predictions
+         outputs = self.classifier(outputs)
+
+         # calculate the resulting outputs with temperature
+         if temperature > 0:
+
+             outputs = torch.softmax(outputs / temperature, dim=-1)
+
+         else:
+
+             outputs = torch.softmax(outputs, dim=-1)
+
+         # calculate the predictions
+         outputs = copy.deepcopy(outputs.detach().cpu())
+         predictions = torch.argmax(outputs, dim=-1).to(next(self.parameters()).device)
+
+         return predictions
+
+     def generate_(
+         self,
+         input_,
+         input_mask=None,
+         temperature: float = 0,
+         max_len: Union[int, None] = None,
+     ):
+
+         if self.training:
+
+             raise GenerationException(
+                 "You cannot generate when the model is on training mode!"
+             )
+
+         # recuperate the max len
+         max_len = max_len if not max_len is None else self.max_len
+
+         # ---> Encoder prediction
+         input_embed = self.embedding_layer(input_)
+
+         # recuperate the last input token
+         last_input = input_[:, -1:]
+
+         # add position to input_embedding
+         input_embed = self.pe(input_embed)
+
+         # recuperate the input mask for pytorch encoder
+         pad_mask1 = (
+             (input_mask == False).to(next(self.parameters()).device)
+             if not input_mask is None
+             else None
+         )
+
+         # let us compute the states
+         input_embed = input_embed.type_as(next(self.encoder.parameters()))
+
+         states = self.encoder(input_embed, src_key_padding_mask=pad_mask1)
+
+         # define the attention mask
+         targ_mask = self.get_target_mask(max_len)
+
+         # in evaluation mode we do not use the target but the outputs to make predictions,
+         # and this is done sequentially (see comments below)
+
+         # let us recuperate the last input token as the current tokens
+         tokens = last_input
+
+         # for each target that we want to predict
+         for t in range(max_len):
+
+             # recuperate the target mask of the current decoder input
+             current_targ_mask = targ_mask[
+                 : t + 1, : t + 1
+             ]  # all attentions between the elements before the last target
+
+             # we do the same for the padding mask
+             current_pad_mask = None
+
+             # if not pad_mask2 is None:
+
+             #     current_pad_mask = pad_mask2[:, :t+1]
+
+             # pass the tokens to the embedding layer to get the embeddings
+             tokens_embed = self.pe(self.embedding_layer(tokens)).type_as(
+                 next(self.encoder.parameters())
+             )
+
+             # make new predictions
+             out = self.decoder(
+                 tokens_embed,
+                 states,
+                 tgt_mask=current_targ_mask,
+                 tgt_key_padding_mask=current_pad_mask,
+             )
+
+             # recuperate probabilities with or without temperature
+             if temperature > 0:
+
+                 probs = torch.softmax(self.classifier(out[:, -1]) / temperature, dim=-1)
+
+             else:
+
+                 probs = torch.softmax(self.classifier(out[:, -1]), dim=-1)
+
+             # let us sample the next token
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             # add the last new prediction to the decoder inputs
+             tokens = torch.cat(
+                 (tokens, next_token), dim=-1
+             )  # the prediction of the last output is the last to add (!)
+
+         # let's take only the predictions (the last input will not be taken)
+         predictions = tokens[:, 1:]
+
+         return predictions
+
+     def beam_generate(
+         self,
+         input_,
+         input_mask=None,
+         temperature: float = 0,
+         max_len: Union[int, None] = None,
+         beam_size: int = 5,
+     ):
+
+         # let us initialize the batch size
+         batch_size = input_.size(0)
+
+         if self.training:
+
+             raise GenerationException(
+                 "You cannot generate when the model is on training mode!"
+             )
+
+         # recuperate the max len
+         max_len = max_len if not max_len is None else self.max_len
+
+         # ---> Encoder prediction
+         input_embed = self.embedding_layer(input_)
+
+         # recuperate the last input token
+         last_input = input_[:, -1:]
+
+         # add position to input_embedding
+         input_embed = self.pe(input_embed)
+
+         # recuperate the input mask for pytorch encoder
+         pad_mask1 = (
+             (input_mask == False).to(next(self.parameters()).device)
+             if not input_mask is None
+             else None
+         )
+
+         # let us compute the states
+         input_embed = input_embed.type_as(next(self.encoder.parameters()))
+
+         states = self.encoder(input_embed, src_key_padding_mask=pad_mask1)
+
+         # define the attention mask
+         targ_mask = self.get_target_mask(max_len)
+
+         # in evaluation mode we do not use the target but the outputs to make predictions,
+         # and this is done sequentially (see comments below)
+
+         # let us recuperate the last input token as the current tokens
+         tokens = last_input
+
+         # generate predictions (beam search written with the help of ChatGPT)
+
+         # let us initialize the beams
+         beams = [tokens[i, -1:].expand(beam_size, -1) for i in range(batch_size)]
+
+         # initialize the beam scores
+         scores = torch.zeros(
+             (batch_size, beam_size), device=next(self.parameters()).device
+         )
+
+         # for each target that we want to predict
+         for t in range(max_len):
+
+             # initialize all of the candidates and the scores
+             all_candidates = []
+             all_scores = []
+
+             # recuperate the target mask of the current decoder input
+             current_targ_mask = targ_mask[
+                 : t + 1, : t + 1
+             ]  # all attentions between the elements before the last target
+
+             # we do the same for the padding mask
+             current_pad_mask = None
+
+             # iterate over the beams and batches to make predictions
+             for be_idx in range(beam_size):
+
+                 # initialize the candidates and scores
+                 candidates = []
+                 candidate_scores = []
+
+                 for ba_idx in range(batch_size):
+
+                     # recuperate the current state
+                     current_state = states[ba_idx].unsqueeze(0)
+
+                     # recuperate the current sequence
+                     tokens = beams[ba_idx][be_idx].unsqueeze(0)
+
+                     # pass the tokens to the embedding layer to get the embeddings
+                     tokens_embed = self.pe(self.embedding_layer(tokens)).type_as(
+                         next(self.encoder.parameters())
+                     )
+
+                     # make new predictions
+                     out = self.decoder(
+                         tokens_embed,
+                         current_state,
+                         tgt_mask=current_targ_mask,
+                         tgt_key_padding_mask=current_pad_mask,
+                     )
+
+                     # recuperate probabilities with or without temperature
+                     if temperature > 0:
+
+                         log_probs = F.log_softmax(
+                             self.classifier(out[:, -1]).squeeze() / temperature, dim=-1
+                         )
+
+                     else:
+
+                         log_probs = F.log_softmax(
+                             self.classifier(out[:, -1]).squeeze(), dim=-1
+                         )
+
+                     # get top k candidates
+                     beam_scores, beam_candidates = log_probs.topk(beam_size, dim=-1)
+
+                     # add the candidates to the set of candidates (do the same for the scores)
+                     candidates.append(beam_candidates)
+                     candidate_scores.append(beam_scores)
+
+                 # add the current set of candidates and scores to the global set
+                 all_candidates.append(torch.stack(candidates))
+                 all_scores.append(torch.stack(candidate_scores))
+
+             # select top k candidates and scores from all beams
+             all_candidates = torch.stack(all_candidates)
+             all_scores = torch.stack(all_scores)
+             topk_scores, topk_idx = all_scores.view(batch_size, -1).topk(beam_size)
+
+             # Update beams and scores for the current iteration
+             new_beams = []
+             new_scores = []
+
+             # iterate over the batches to update the beams and scores
+             for ba_idx in range(batch_size):
+
+                 # recuperate candidates
+                 beam_candidates = all_candidates[:, ba_idx].reshape(-1)
+
+                 # recuperate indices
+                 selected_indices = topk_idx[ba_idx]
+
+                 # recuperate the beams
+                 selected_beams = selected_indices // beam_size
+
+                 # recuperate the tokens
+                 selected_tokens = beam_candidates[selected_indices % beam_size]
+
+                 new_beams.append(
+                     [
+                         torch.concatenate(
+                             (
+                                 beams[ba_idx][selected_beams[i]],
+                                 selected_tokens[i].unsqueeze(0),
+                             ),
+                             dim=-1,
+                         )
+                         for i in range(beam_size)
+                     ]
+                 )
+
+                 new_scores.append(topk_scores[ba_idx])
+
+             # update the beams and scores
+             beams = new_beams
+             scores = new_scores
+
+         # recuperate the top candidates for each sequence in the batch
+         predictions = torch.stack([beams[i][0].squeeze() for i in range(batch_size)])
+
+         # let's take only the predictions (the last input will not be taken)
+         predictions = predictions[:, 1:]
+
+         return predictions
+
+     def diverse_beam_generate(
+         self,
+         input_,
+         input_mask=None,
+         temperature: float = 0,
+         max_len: Union[int, None] = None,
+         beam_size: int = 5,
+         beam_groups: int = 1,
+         diversity_penalty: float = 0.5,
+     ):
+
+         # let us initialize the batch size
+         batch_size = input_.size(0)
+
+         if self.training:
+
+             raise GenerationException(
+                 "You cannot generate when the model is on training mode!"
+             )
+
+         # recuperate the max len
+         max_len = max_len if not max_len is None else self.max_len
+
+         # ---> Encoder prediction
+         input_embed = self.embedding_layer(input_)
+
+         # recuperate the last input token
+         last_input = input_[:, -1:]
+
+         # add position to input_embedding
+         input_embed = self.pe(input_embed)
+
+         # recuperate the input mask for pytorch encoder
+         pad_mask1 = (
+             (input_mask == False).to(next(self.parameters()).device)
+             if not input_mask is None
+             else None
+         )
+
+         # let us compute the states
+         input_embed = input_embed.type_as(next(self.encoder.parameters()))
+
+         states = self.encoder(input_embed, src_key_padding_mask=pad_mask1)
+
+         # define the attention mask
+         targ_mask = self.get_target_mask(max_len)
+
+         # in evaluation mode we do not use the target but the outputs to make predictions,
+         # and this is done sequentially (see comments below)
+
+         # let us recuperate the last input token as the current tokens
+         tokens = last_input
+
+         # generate predictions (beam search written with the help of ChatGPT)
+
+         # let us initialize the beams
+         beams = [tokens[i, -1:].expand(beam_size, -1) for i in range(batch_size)]
+
+         # initialize the beam scores
+         scores = torch.zeros(
+             (batch_size, beam_size), device=next(self.parameters()).device
+         )
+
+         # for each target that we want to predict
+         for t in range(max_len):
+
+             # initialize all of the candidates and the scores
+             all_candidates = []
+             all_scores = []
+
+             # recuperate the target mask of the current decoder input
+             current_targ_mask = targ_mask[
+                 : t + 1, : t + 1
+             ]  # all attentions between the elements before the last target
+
+             # we do the same for the padding mask
+             current_pad_mask = None
+
+             # iterate over the beams and batches to make predictions
+             for be_idx in range(beam_size):
+
+                 # initialize the candidates and scores
+                 candidates = []
+                 candidate_scores = []
+
+                 for ba_idx in range(batch_size):
+
+                     # recuperate the current state
+                     current_state = states[ba_idx].unsqueeze(0)
+
+                     # recuperate the current sequence
+                     tokens = beams[ba_idx][be_idx].unsqueeze(0)
+
+                     # pass the tokens to the embedding layer to get the embeddings
+                     tokens_embed = self.pe(self.embedding_layer(tokens)).type_as(
+                         next(self.encoder.parameters())
+                     )
+
+                     # make new predictions
+                     out = self.decoder(
+                         tokens_embed,
+                         current_state,
+                         tgt_mask=current_targ_mask,
+                         tgt_key_padding_mask=current_pad_mask,
+                     )
+
+                     # recuperate probabilities with or without temperature
+                     if temperature > 0:
+
+                         log_probs = F.log_softmax(
+                             self.classifier(out[:, -1]).squeeze() / temperature, dim=-1
+                         )
+
+                     else:
+
+                         log_probs = F.log_softmax(
+                             self.classifier(out[:, -1]).squeeze(), dim=-1
+                         )
+
+                     # get top k candidates
+                     beam_scores, beam_candidates = log_probs.topk(beam_size, dim=-1)
+
+                     # add the candidates to the set of candidates (do the same for the scores)
+                     candidates.append(beam_candidates)
+                     candidate_scores.append(beam_scores)
+
+                 # add the current set of candidates and scores to the global set
+                 all_candidates.append(torch.stack(candidates))
+                 all_scores.append(torch.stack(candidate_scores))
+
+             # select top k candidates and scores from all beams
+             all_candidates = torch.stack(all_candidates)
+             all_scores = torch.stack(all_scores)
+
+             # reshape candidates and scores for efficient matrix operations
+             all_candidates_flat = all_candidates.reshape(beam_size, -1)
+             all_scores_flat = all_scores.reshape(beam_size, -1)
+
+             # apply the diversity penalty to the scores for each beam group
+             group_size = beam_size // beam_groups
+             for group_idx in range(beam_groups):
+
+                 group_start = group_idx * group_size
+
+                 group_end = (group_idx + 1) * group_size
+
+                 group_candidates = all_candidates_flat[group_start:group_end]
+
+                 group_scores = all_scores_flat[group_start:group_end]
+
+                 diversity_penalty_ = self.hamming_distance(
+                     group_candidates.unsqueeze(2), group_candidates.unsqueeze(1)
+                 )
+
+                 penalty = diversity_penalty * diversity_penalty_.view(
+                     group_size, -1
+                 ).sum(dim=0)
+
+                 group_scores -= penalty
+
+             # reshape the scores back to the original shape
+             all_scores = all_scores_flat.reshape(beam_size, batch_size, -1)
+
+             topk_scores, topk_idx = all_scores.view(batch_size, -1).topk(beam_size)
+
+             # Update beams and scores for the current iteration
+             new_beams = []
+             new_scores = []
+
+             # iterate over the batches to update the beams and scores
+             for ba_idx in range(batch_size):
+
+                 # recuperate candidates
+                 beam_candidates = all_candidates[:, ba_idx].reshape(-1)
+
+                 # recuperate indices
+                 selected_indices = topk_idx[ba_idx]
+
+                 # recuperate the beams
+                 selected_beams = selected_indices // beam_size
+
+                 # recuperate the tokens
+                 selected_tokens = beam_candidates[selected_indices % beam_size]
+
+                 new_beams.append(
+                     [
+                         torch.concatenate(
+                             (
+                                 beams[ba_idx][selected_beams[i]],
+                                 selected_tokens[i].unsqueeze(0),
+                             ),
+                             dim=-1,
+                         )
+                         for i in range(beam_size)
+                     ]
+                 )
+
+                 new_scores.append(topk_scores[ba_idx])
+
+             # update the beams and scores
+             beams = new_beams
+             scores = new_scores
+
+         # recuperate the top candidates for each sequence in the batch
+         predictions = torch.stack([beams[i][0].squeeze() for i in range(batch_size)])
+
+         # let's take only the predictions (the last input will not be taken)
+         predictions = predictions[:, 1:]
+
+         return predictions
+
+     def hamming_distance(self, sequence_1, sequence_2):
+
+         # Calculate the Hamming distance between two sequences
+         return (sequence_1 != sequence_2).sum(axis=-1)
+
+     def get_target_mask(self, attention_size: int):
+
+         return torch.triu(torch.ones((attention_size, attention_size)), diagonal=1).to(
+             next(self.parameters()).device, dtype=torch.bool
+         )
+
+     def _copy_embedding_weights(self, module, input, output):
+         # Copy the embedding weights to the last dense layer
+         self.classifier.weight.data = module.weight.data
+
+
+ # %%
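Below is a minimal usage sketch, not part of the wheel, showing how this Transformer can be wired up. The constructor only inspects the first layer of whatever encoder and decoder it receives (layers._modules["0"].linear1.in_features and .dropout.p), and the model indexes every tensor batch-first, so standard torch.nn.TransformerEncoder / TransformerDecoder stacks built with batch_first=True appear to satisfy those expectations. The sizes (d_model, nhead, vocab_size) and the toy batches are hypothetical placeholders; the 1/0 masks follow the convention used in forward, where 0 marks padding, and pad_token_id keeps its default of 3.

    import torch
    from torch import nn

    from wolof_translate.models.transformers.main import Transformer

    # hypothetical sizes; adjust them to the tokenizer actually in use
    d_model, vocab_size = 512, 8000

    # batch-first layers so that dimension 1 is the sequence dimension, as the model assumes
    encoder_layer = nn.TransformerEncoderLayer(d_model, nhead=8, dropout=0.1, batch_first=True)
    decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, dropout=0.1, batch_first=True)

    model = Transformer(
        vocab_size,
        nn.TransformerEncoder(encoder_layer, num_layers=6),
        nn.TransformerDecoder(decoder_layer, num_layers=6),
        max_len=20,
    )

    # one training step: batch-first token ids with 1/0 attention masks (0 marks padding)
    src = torch.randint(0, vocab_size, (2, 12))
    tgt = torch.randint(0, vocab_size, (2, 9))
    out = model(src, torch.ones_like(src), tgt, torch.ones_like(tgt))
    out["loss"].backward()

    # greedy decoding once the model has been switched to evaluation mode
    model.eval()
    with torch.no_grad():
        preds = model.generate(src, torch.ones_like(src), max_len=20)

The other decoding entry points, generate_, beam_generate and diverse_beam_generate, take the same (input_, input_mask) pair and differ only in how the next token is chosen: multinomial sampling, beam search, and group-diverse beam search respectively.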