translate-package 0.0.8-py3-none-any.whl → 0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,7 +14,7 @@ from transformers import (
     Trainer, AutoModelForSeq2SeqLM,
     get_linear_schedule_with_warmup,
     T5ForConditionalGeneration, Adafactor, BartForConditionalGeneration,
-    MT5ForConditionalGeneration, AdamWeightDecay
+    MT5ForConditionalGeneration, AdamWeightDecay, AutoTokenizer
 )
 from wolof_translate.utils.bucket_iterator import SequenceLengthBatchSampler, BucketSampler
 from wolof_translate.utils.sent_transformers import TransformerSequences
@@ -11,6 +11,7 @@ from translate_package import (
     MT5ForConditionalGeneration,
     BartForConditionalGeneration,
     AutoModelForSeq2SeqLM,
+    AutoTokenizer,
     Adafactor,
     AdamWeightDecay
 )
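
With this change AutoTokenizer is importable from the package namespace alongside the model classes already listed in the import block above. A brief, hedged illustration of what that enables; the checkpoint name below is a placeholder, not something the diff specifies:

# Assumes translate_package 0.1.0 exposes AutoTokenizer as shown in the import list above.
from translate_package import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint, used only to illustrate the re-export.
name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)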
@@ -57,7 +58,7 @@ class MachineTranslationTransformer(pl.LightningModule):
         hidden_size=128,
         dropout=0.1,
         bidirectional=False,
-        length_penalty=1.2
+        length_penalty=1.2,
     ):
 
         super().__init__()
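
Only a trailing comma is added here, but for context: length_penalty is a beam-search generation argument. A self-contained sketch of where such a value typically ends up, using a generic seq2seq checkpoint as a stand-in for the package's models (the checkpoint and beam settings are assumptions, not part of the diff):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Stand-in checkpoint for illustration only.
name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

inputs = tokenizer("translate English to French: Hello world", return_tensors="pt")
ids = model.generate(
    **inputs,
    num_beams=4,
    length_penalty=1.2,  # matches the constructor default exposed above
    max_new_tokens=40,
)
print(tokenizer.decode(ids[0], skip_special_tokens=True))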
@@ -157,7 +158,7 @@ class MachineTranslationTransformer(pl.LightningModule):
                 warmup_init = False
             )
 
-        elif self.model_generation in ["bart", "mbart"]:
+        elif self.model_generation in ["bart", "mbart", "nllb"]:
 
             optimizer = torch.optim.AdamW(
                 self.parameters(), lr=self.lr, weight_decay=self.weight_decay
@@ -169,11 +170,11 @@ class MachineTranslationTransformer(pl.LightningModule):
                 self.parameters(), lr=self.lr, weight_decay=self.weight_decay
             )
 
-        if self.model_generation in ["t5", "lstm", "mt5"]:
+        if self.model_generation in ["t5", "lstm"]:
 
             return [optimizer]
 
-        elif self.model_generation in ["bart"]:
+        elif self.model_generation in ["bart", "mt5"]:
 
             scheduler = get_linear_schedule_with_warmup(
                 optimizer,
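
Taken together, this hunk and the previous one change two things in configure_optimizers: NLLB now falls into the AdamW branch alongside BART/mBART, and mT5 now receives the linear warmup schedule instead of returning the bare optimizer. A standalone sketch of that selection logic follows; the Adafactor branch, the fallback branch, and the argument names (num_warmup_steps, num_training_steps) are assumptions based on the visible context, not verbatim package code:

import torch
from transformers import Adafactor, get_linear_schedule_with_warmup

def configure_optimizers_sketch(params, model_generation, lr, weight_decay,
                                num_warmup_steps, num_training_steps):
    # Assumed: T5-style models use Adafactor (warmup_init appears just above
    # the BART branch in the diff); BART/mBART/NLLB use AdamW as of 0.1.0.
    if model_generation in ["t5", "mt5"]:
        optimizer = Adafactor(
            params, lr=lr, weight_decay=weight_decay,
            scale_parameter=False, relative_step=False, warmup_init=False,
        )
    else:  # "bart", "mbart", "nllb", and any other generation
        optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    # T5 and LSTM return only the optimizer; BART and (since 0.1.0) mT5
    # also get the linear warmup schedule.
    if model_generation in ["t5", "lstm"]:
        return [optimizer]
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return [optimizer], [scheduler]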
@@ -3,13 +3,18 @@ from transformers import T5TokenizerFast
 from transformers import AutoTokenizer
 import os
 
-def load_tokenizer(tokenizer_name, model, dir_path, file_name, model_name = None):
+BCP_47_languages = {
+    'french': 'fra_Latn',
+    'wolof': 'wol_Latn',
+}
+
+def load_tokenizer(tokenizer_name, model, dir_path, file_name, model_name = None, src_lang = "french", tgt_lang = "wolof"):
 
     if model == "nllb":
 
         if not model_name is None:
 
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang = BCP_47_languages[src_lang], tgt_lang = BCP_47_languages[tgt_lang])
 
             print(f"The {model}'s tokenizer was successfully loaded")
 
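The new src_lang/tgt_lang parameters let the NLLB tokenizer tag the source and target sides with their BCP-47 codes. A minimal sketch of the resulting call, assuming a standard NLLB checkpoint; the checkpoint name and example sentences are illustrative only, the BCP-47 codes are the ones added above:

from transformers import AutoTokenizer

BCP_47_languages = {
    "french": "fra_Latn",
    "wolof": "wol_Latn",
}

# Hypothetical NLLB checkpoint, used only to illustrate the call.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    src_lang=BCP_47_languages["french"],
    tgt_lang=BCP_47_languages["wolof"],
)

# src_lang controls the language token prepended to the source text;
# text_target tokenizes the target side under tgt_lang and returns labels.
batch = tokenizer(
    "Bonjour, comment allez-vous ?",
    text_target="Na nga def ?",
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["labels"].shape)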
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: translate-package
-Version: 0.0.8
+Version: 0.1.0
 Summary: Contain functions and classes to efficiently train a sequence to sequence to translate between two languages.
 Author: Oumar Kane
 Author-email: oumar.kane@univ-thies.sn
@@ -1,17 +1,17 @@
-translate_package/__init__.py,sha256=Nckjm15LBEfKSU5-EBjfRkHJQhWELzwow0BD3rKtmkw,1297
+translate_package/__init__.py,sha256=miie3aAeUYHsVk2O-kd4T86fFksuCiY70Eo6RNeY1Oo,1312
 translate_package/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 translate_package/data/data_preparation.py,sha256=AQFmIJawPL5kaWGhTGPCfizIQD00XqQ-cANaICZ-1Ow,14882
 translate_package/errors/__init__.py,sha256=gu6XjAIghG4lLkYo8x_7_yyLRtK2FIvmC-WcfJaeOlg,299
 translate_package/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 translate_package/models/gradient_observation.py,sha256=P91UA5i-RdkK46TqpPOJ54DsUYgTI9cRohgPS1Ch0Lc,294
 translate_package/models/lstm.py,sha256=OPkvvceowz5JqdGGH4cfPhH23kbP11z-29zIJn5d8ig,3273
-translate_package/models/machine_translation.py,sha256=j8ZXh9UEElixwy7RwBFE_sO0EZrkaMOGs7W39EWvXhc,10575
+translate_package/models/machine_translation.py,sha256=a9YOEgsuGnkAvaEv8jML95YESNgkHu96ZXHm6_esv44,10604
 translate_package/tokenization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-translate_package/tokenization/load_tokenizer.py,sha256=vzCHS0ZDSJyr0y08zNvupMtD2jP8A16EBN-ob0LJHG0,1344
+translate_package/tokenization/load_tokenizer.py,sha256=g8j5pDmimFhwjpeYNkWot0hXMzAqqURbtedcQK-1xYE,1543
 translate_package/tokenization/train_tokenizer.py,sha256=RkdT5DUx201OBNaswM6m54iqcrmCThd3ITLguQb_zVM,3347
 translate_package/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 translate_package/utils/checkpoint.py,sha256=GqymRvF8_QZgrQq9m79Ppj6Qr7NQm78kDARm3p_chC0,322
-translate_package-0.0.8.dist-info/METADATA,sha256=GiNYpQ6GrFumBn--bFa0cFijKTYjoMV2Wib8cm7R6Cg,850
-translate_package-0.0.8.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
-translate_package-0.0.8.dist-info/top_level.txt,sha256=8e2HIrGAMzoSukqu2q929dOJMV1zGYKI_BAFwl-P7XU,18
-translate_package-0.0.8.dist-info/RECORD,,
+translate_package-0.1.0.dist-info/METADATA,sha256=l90EsBgKdqObcW_Q8w-2A3UpM-ji8Yn8V3QyLxUg-gk,850
+translate_package-0.1.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+translate_package-0.1.0.dist-info/top_level.txt,sha256=8e2HIrGAMzoSukqu2q929dOJMV1zGYKI_BAFwl-P7XU,18
+translate_package-0.1.0.dist-info/RECORD,,