nshtrainer 0.19.0__py3-none-any.whl → 0.19.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import io
2
2
  import logging
3
3
  import os
4
+ import re
4
5
  from pathlib import Path
5
6
  from typing import TYPE_CHECKING, Any, cast
6
7
 
@@ -150,7 +151,32 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
150
151
  elif (username := api.whoami().get("name", None)) is None:
151
152
  raise ValueError("Could not get username from Hugging Face Hub.")
152
153
 
153
- return f"{username}/{root_config.project}-{root_config.run_name}-{root_config.id}"
154
+ # Sanitize the project (if it exists), run_name, and id
155
+ parts = []
156
+ if root_config.project:
157
+ parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
158
+ parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
159
+ parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
160
+
161
+ # Combine parts and ensure it starts and ends with alphanumeric characters
162
+ repo_name = "-".join(parts)
163
+ repo_name = repo_name.strip("-")
164
+ repo_name = re.sub(
165
+ r"-+", "-", repo_name
166
+ ) # Replace multiple dashes with a single dash
167
+
168
+ # Ensure the name is not longer than 96 characters (excluding username)
169
+ if len(repo_name) > 96:
170
+ repo_name = repo_name[:96].rstrip("-")
171
+
172
+ # Ensure the repo name starts with an alphanumeric character
173
+ repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
174
+
175
+ # If the repo_name is empty after all sanitization, use a default name
176
+ if not repo_name:
177
+ repo_name = "default-repo-name"
178
+
179
+ return f"{username}/{repo_name}"
154
180
 
155
181
 
156
182
  def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
@@ -1,7 +1,6 @@
1
1
  import logging
2
- from collections import abc
3
- from collections.abc import Callable, Iterable
4
- from typing import Any, TypeAlias, cast, final
2
+ from collections.abc import Callable, Iterable, Sequence
3
+ from typing import Any, TypeAlias, cast, final, overload
5
4
 
6
5
  from lightning.pytorch import Callback, LightningModule
7
6
  from lightning.pytorch.callbacks import LambdaCallback
@@ -19,11 +18,61 @@ class CallbackRegistrarModuleMixin:
19
18
  def __init__(self, *args, **kwargs):
20
19
  super().__init__(*args, **kwargs)
21
20
 
22
- self._ll_callbacks: list[CallbackFn] = []
21
+ self._nshtrainer_callbacks: list[CallbackFn] = []
22
+
23
+ @overload
24
+ def register_callback(
25
+ self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
26
+ ): ...
27
+
28
+ @overload
29
+ def register_callback(
30
+ self,
31
+ /,
32
+ *,
33
+ setup: Callable | None = None,
34
+ teardown: Callable | None = None,
35
+ on_fit_start: Callable | None = None,
36
+ on_fit_end: Callable | None = None,
37
+ on_sanity_check_start: Callable | None = None,
38
+ on_sanity_check_end: Callable | None = None,
39
+ on_train_batch_start: Callable | None = None,
40
+ on_train_batch_end: Callable | None = None,
41
+ on_train_epoch_start: Callable | None = None,
42
+ on_train_epoch_end: Callable | None = None,
43
+ on_validation_epoch_start: Callable | None = None,
44
+ on_validation_epoch_end: Callable | None = None,
45
+ on_test_epoch_start: Callable | None = None,
46
+ on_test_epoch_end: Callable | None = None,
47
+ on_validation_batch_start: Callable | None = None,
48
+ on_validation_batch_end: Callable | None = None,
49
+ on_test_batch_start: Callable | None = None,
50
+ on_test_batch_end: Callable | None = None,
51
+ on_train_start: Callable | None = None,
52
+ on_train_end: Callable | None = None,
53
+ on_validation_start: Callable | None = None,
54
+ on_validation_end: Callable | None = None,
55
+ on_test_start: Callable | None = None,
56
+ on_test_end: Callable | None = None,
57
+ on_exception: Callable | None = None,
58
+ on_save_checkpoint: Callable | None = None,
59
+ on_load_checkpoint: Callable | None = None,
60
+ on_before_backward: Callable | None = None,
61
+ on_after_backward: Callable | None = None,
62
+ on_before_optimizer_step: Callable | None = None,
63
+ on_before_zero_grad: Callable | None = None,
64
+ on_predict_start: Callable | None = None,
65
+ on_predict_end: Callable | None = None,
66
+ on_predict_batch_start: Callable | None = None,
67
+ on_predict_batch_end: Callable | None = None,
68
+ on_predict_epoch_start: Callable | None = None,
69
+ on_predict_epoch_end: Callable | None = None,
70
+ ): ...
23
71
 
24
72
  def register_callback(
25
73
  self,
26
74
  callback: Callback | Iterable[Callback] | CallbackFn | None = None,
75
+ /,
27
76
  *,
28
77
  setup: Callable | None = None,
29
78
  teardown: Callable | None = None,
@@ -109,7 +158,7 @@ class CallbackRegistrarModuleMixin:
109
158
  else:
110
159
  callback_ = callback
111
160
 
112
- self._ll_callbacks.append(callback_)
161
+ self._nshtrainer_callbacks.append(callback_)
113
162
 
114
163
 
115
164
  class CallbackModuleMixin(
@@ -130,13 +179,13 @@ class CallbackModuleMixin(
130
179
  if isinstance(module, CallbackRegistrarModuleMixin)
131
180
  )
132
181
  for module in modules:
133
- yield from module._ll_callbacks
182
+ yield from module._nshtrainer_callbacks
134
183
 
135
184
  @final
136
185
  @override
137
186
  def configure_callbacks(self):
138
187
  callbacks = super().configure_callbacks()
139
- if not isinstance(callbacks, abc.Sequence):
188
+ if not isinstance(callbacks, Sequence):
140
189
  callbacks = [callbacks]
141
190
 
142
191
  callbacks = list(callbacks)
@@ -145,7 +194,7 @@ class CallbackModuleMixin(
145
194
  if callback_result is None:
146
195
  continue
147
196
 
148
- if not isinstance(callback_result, abc.Iterable):
197
+ if not isinstance(callback_result, Iterable):
149
198
  callback_result = [callback_result]
150
199
 
151
200
  for callback in callback_result:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.19.0
3
+ Version: 0.19.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -3,7 +3,7 @@ nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjX
3
3
  nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
4
4
  nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
5
5
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
6
- nshtrainer/_hf_hub.py,sha256=Py9_8ADvMCFPaJzeE7bxm8Mgs3mEMkyWJ4pDEccTGt8,11230
6
+ nshtrainer/_hf_hub.py,sha256=To3BnnGWbMNNMBdzVtgrNOcNU2fi1dQpwwuclusFAbI,12169
7
7
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
8
8
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
9
9
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
@@ -58,7 +58,7 @@ nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ
58
58
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
59
59
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
60
60
  nshtrainer/model/config.py,sha256=147uV7IukvuYE4G_ZuQNxVjnlog1BdCrAVbcj_sx9Vs,43104
61
- nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
61
+ nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
62
62
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
63
63
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
64
64
  nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
85
85
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
86
86
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
87
87
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
88
- nshtrainer-0.19.0.dist-info/METADATA,sha256=VLb38BSORQBx6g_SfGnbdBWa37N9xCtZ-JI45ATouzY,935
89
- nshtrainer-0.19.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
90
- nshtrainer-0.19.0.dist-info/RECORD,,
88
+ nshtrainer-0.19.2.dist-info/METADATA,sha256=InNVoRQEPpPRCFbBje-ekgQzFFycxC9VzQsmEqUJK1c,935
89
+ nshtrainer-0.19.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
90
+ nshtrainer-0.19.2.dist-info/RECORD,,