translate-package 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,8 @@ from translate_package.models.gradient_observation import get_gradients_mean
20
20
 
21
21
  from translate_package.models.lstm import LSTMSequenceToSequence
22
22
 
23
+ import os
24
+
23
25
 
24
26
  def print_number_of_trainable_model_parameters(model):
25
27
  trainable_model_params = 0
@@ -326,3 +328,20 @@ class MachineTranslationTransformer(pl.LightningModule):
326
328
  self.log_dict(
327
329
  metrics, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True
328
330
  )
331
+
332
def save_model(self, directory: str = "my_model", model_name: str = "pytorch_model.bin"):
    """Save the wrapped model's weights, its config (if present) and the tokenizer.

    Args:
        directory: Output folder. It is created, including missing parents,
            if it does not exist yet; an existing folder is reused.
        model_name: File name (inside ``directory``) for the serialized
            ``state_dict`` of ``self.model``.
    """
    # exist_ok=True replaces the racy "os.path.exists, then os.makedirs" pattern:
    # the folder may appear between the check and the creation.
    os.makedirs(directory, exist_ok=True)

    # Only the parameters/buffers (state_dict) are stored, not the module object.
    torch.save(self.model.state_dict(), os.path.join(directory, model_name))

    # Models that expose a `config` (presumably Hugging Face models — verify)
    # additionally get a config.json next to the weights.
    if hasattr(self.model, "config"):
        self.model.config.to_json_file(os.path.join(directory, "config.json"))

    self.tokenizer.save_pretrained(directory)
@@ -0,0 +1,15 @@
1
+ from huggingface_hub import login, HfApi, upload_folder, create_repo
2
+
3
+
4
def upload_model(
    hub_token,
    directory="my_model",
    username="",
    repo_name="",
    commit_message="new model created",
):
    """Create (if needed) a Hugging Face Hub repository and upload a local folder to it.

    Args:
        hub_token: Hugging Face access token with write permission.
        directory: Local folder whose contents are uploaded.
        username: Namespace (user or organisation) owning the repository. When
            empty, the bare ``repo_name`` is used and the Hub resolves the
            namespace from the token.
        repo_name: Name of the repository on the Hub. Required.
        commit_message: Commit message recorded for the upload.

    Returns:
        The value returned by ``upload_folder`` (commit information).

    Raises:
        ValueError: If ``repo_name`` is empty.
    """
    if not repo_name:
        raise ValueError("repo_name must be a non-empty repository name.")

    # An empty username would otherwise produce the invalid repo id "/<name>".
    repo_id = f"{username}/{repo_name}" if username else repo_name

    # Kept for backward compatibility: callers may rely on being logged in
    # afterwards. NOTE(review): login() may persist the token locally.
    login(token=hub_token)

    # exist_ok=True lets a new model version be pushed to an existing repo
    # instead of failing because the repository already exists.
    create_repo(repo_id, token=hub_token, exist_ok=True)

    commit_info = upload_folder(
        repo_id=repo_id,
        folder_path=directory,
        commit_message=commit_message,
        token=hub_token,
    )

    print(f"Model was successfully uploaded to {repo_id}.")

    return commit_info
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: translate-package
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: Contain functions and classes to efficiently train a sequence to sequence to translate between two languages.
5
5
  Author: Oumar Kane
6
6
  Author-email: oumar.kane@univ-thies.sn
7
7
  Requires-Dist: accelerate
8
- Requires-Dist: torch
8
+ Requires-Dist: torch (==2.6.0+cu124)
9
9
  Requires-Dist: spacy
10
10
  Requires-Dist: nltk
11
11
  Requires-Dist: gensim
@@ -5,13 +5,14 @@ translate_package/errors/__init__.py,sha256=gu6XjAIghG4lLkYo8x_7_yyLRtK2FIvmC-Wc
5
5
  translate_package/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  translate_package/models/gradient_observation.py,sha256=P91UA5i-RdkK46TqpPOJ54DsUYgTI9cRohgPS1Ch0Lc,294
7
7
  translate_package/models/lstm.py,sha256=OPkvvceowz5JqdGGH4cfPhH23kbP11z-29zIJn5d8ig,3273
8
- translate_package/models/machine_translation.py,sha256=e59c88IDElLxt8dTlgziJ9OGThMbwTNtixBjK94gkAQ,10612
8
+ translate_package/models/machine_translation.py,sha256=1ot9Me6U1O7UHJMuJGvatx3DxoKY9TghzzHNzxdZa5g,11170
9
9
  translate_package/tokenization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  translate_package/tokenization/load_tokenizer.py,sha256=g8j5pDmimFhwjpeYNkWot0hXMzAqqURbtedcQK-1xYE,1543
11
11
  translate_package/tokenization/train_tokenizer.py,sha256=RkdT5DUx201OBNaswM6m54iqcrmCThd3ITLguQb_zVM,3347
12
12
  translate_package/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  translate_package/utils/checkpoint.py,sha256=GqymRvF8_QZgrQq9m79Ppj6Qr7NQm78kDARm3p_chC0,322
14
- translate_package-0.1.3.dist-info/METADATA,sha256=1Z2i5sbtnvm3MrQCIklpdByQRXv9Iwg4Ug1O0M-ekew,850
15
- translate_package-0.1.3.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
16
- translate_package-0.1.3.dist-info/top_level.txt,sha256=8e2HIrGAMzoSukqu2q929dOJMV1zGYKI_BAFwl-P7XU,18
17
- translate_package-0.1.3.dist-info/RECORD,,
14
+ translate_package/utils/upload_to_hughub.py,sha256=0qihZIAAUuJXfOZ23Njz0aWpDpe8twQNDGPplgrIfzA,480
15
+ translate_package-0.1.5.dist-info/METADATA,sha256=d7FdRXTjpyBQ0dNWC4whsHRqiVMiJA2pA3RVkk18Y90,866
16
+ translate_package-0.1.5.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
17
+ translate_package-0.1.5.dist-info/top_level.txt,sha256=8e2HIrGAMzoSukqu2q929dOJMV1zGYKI_BAFwl-P7XU,18
18
+ translate_package-0.1.5.dist-info/RECORD,,