translate-package 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,7 +48,7 @@ class MachineTranslationTransformer(pl.LightningModule):
48
48
  lora_alpha=32,
49
49
  lora_dropout=0.05,
50
50
  bias="none",
51
- max_new_tokens=200,
51
+ max_new_tokens=90,
52
52
  predict_with_generate=True,
53
53
  num_beams=0,
54
54
  use_peft=False,
@@ -56,7 +56,8 @@ class MachineTranslationTransformer(pl.LightningModule):
56
56
  num_layers=6,
57
57
  hidden_size=128,
58
58
  dropout=0.1,
59
- bidirectional=False
59
+ bidirectional=False,
60
+ length_penalty=1.2
60
61
  ):
61
62
 
62
63
  super().__init__()
@@ -131,6 +132,8 @@ class MachineTranslationTransformer(pl.LightningModule):
131
132
 
132
133
  self.num_beams = num_beams
133
134
 
135
+ self.length_penalty = length_penalty
136
+
134
137
  self.model_generation = model_generation
135
138
 
136
139
  self.predictions = {
@@ -263,17 +266,33 @@ class MachineTranslationTransformer(pl.LightningModule):
263
266
  )
264
267
 
265
268
  # generate predictions
266
- predictions = self.model.generate(
267
- input_ids=batch["input_ids"],
268
- attention_mask=batch["attention_mask"],
269
- max_new_tokens=self.max_new_tokens,
270
- do_sample=self.num_beams > 0,
271
- num_beams=self.num_beams
272
- ) if not self.model_generation in ["lstm"] else self.model.generate(
273
- input=batch["input_ids"],
274
- max_new_tokens=self.max_new_tokens,
275
- use_sampling=True
276
- )
269
+ if not self.model_generation in ["lstm"] and self.num_beams > 0:
270
+
271
+ predictions = self.model.generate(
272
+ input_ids=batch["input_ids"],
273
+ attention_mask=batch["attention_mask"],
274
+ max_new_tokens=self.max_new_tokens,
275
+ do_sample=True,
276
+ num_beams=self.num_beams,
277
+ length_penalty=self.length_penalty
278
+ )
279
+
280
+ elif not self.model_generation in ["lstm"]:
281
+
282
+ predictions = self.model.generate(
283
+ input_ids=batch["input_ids"],
284
+ attention_mask=batch["attention_mask"],
285
+ max_new_tokens=self.max_new_tokens,
286
+ do_sample=False
287
+ )
288
+
289
+ else:
290
+
291
+ predictions = self.model.generate(
292
+ input=batch["input_ids"],
293
+ max_new_tokens=self.max_new_tokens,
294
+ use_sampling=True
295
+ )
277
296
 
278
297
  # decode the labels
279
298
  predictions = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: translate-package
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: Contain functions and classes to efficiently train a sequence to sequence to translate between two languages.
5
5
  Author: Oumar Kane
6
6
  Author-email: oumar.kane@univ-thies.sn
@@ -5,13 +5,13 @@ translate_package/errors/__init__.py,sha256=gu6XjAIghG4lLkYo8x_7_yyLRtK2FIvmC-Wc
5
5
  translate_package/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  translate_package/models/gradient_observation.py,sha256=P91UA5i-RdkK46TqpPOJ54DsUYgTI9cRohgPS1Ch0Lc,294
7
7
  translate_package/models/lstm.py,sha256=OPkvvceowz5JqdGGH4cfPhH23kbP11z-29zIJn5d8ig,3273
8
- translate_package/models/machine_translation.py,sha256=5QQpjs_HR9mnPryMyfYpcMgU5tHAAj-eVrv3oGmjR5Y,9963
8
+ translate_package/models/machine_translation.py,sha256=LW1qNVAP-c2rKYC58L6cPQ4LU2ePD8wg_Z8CcHSI8MY,10575
9
9
  translate_package/tokenization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  translate_package/tokenization/load_tokenizer.py,sha256=vzCHS0ZDSJyr0y08zNvupMtD2jP8A16EBN-ob0LJHG0,1344
11
11
  translate_package/tokenization/train_tokenizer.py,sha256=RkdT5DUx201OBNaswM6m54iqcrmCThd3ITLguQb_zVM,3347
12
12
  translate_package/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  translate_package/utils/checkpoint.py,sha256=GqymRvF8_QZgrQq9m79Ppj6Qr7NQm78kDARm3p_chC0,322
14
- translate_package-0.0.2.dist-info/METADATA,sha256=OikhBGLwAfDae4ZhSbzNNdHVfj08UxIjSR_P6KryzMA,860
15
- translate_package-0.0.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
16
- translate_package-0.0.2.dist-info/top_level.txt,sha256=8e2HIrGAMzoSukqu2q929dOJMV1zGYKI_BAFwl-P7XU,18
17
- translate_package-0.0.2.dist-info/RECORD,,
14
+ translate_package-0.0.4.dist-info/METADATA,sha256=rT99Y10OFAqYCHBsjQiXxydlQhcFS2Ak-zLheAQJEE0,860
15
+ translate_package-0.0.4.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
16
+ translate_package-0.0.4.dist-info/top_level.txt,sha256=8e2HIrGAMzoSukqu2q929dOJMV1zGYKI_BAFwl-P7XU,18
17
+ translate_package-0.0.4.dist-info/RECORD,,