nshtrainer 0.19.0__py3-none-any.whl → 0.19.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import io
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
+
import re
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import TYPE_CHECKING, Any, cast
|
|
6
7
|
|
|
@@ -150,7 +151,32 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
|
|
|
150
151
|
elif (username := api.whoami().get("name", None)) is None:
|
|
151
152
|
raise ValueError("Could not get username from Hugging Face Hub.")
|
|
152
153
|
|
|
153
|
-
|
|
154
|
+
# Sanitize the project (if it exists), run_name, and id
|
|
155
|
+
parts = []
|
|
156
|
+
if root_config.project:
|
|
157
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
|
|
158
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
|
|
159
|
+
parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
|
|
160
|
+
|
|
161
|
+
# Combine parts and ensure it starts and ends with alphanumeric characters
|
|
162
|
+
repo_name = "-".join(parts)
|
|
163
|
+
repo_name = repo_name.strip("-")
|
|
164
|
+
repo_name = re.sub(
|
|
165
|
+
r"-+", "-", repo_name
|
|
166
|
+
) # Replace multiple dashes with a single dash
|
|
167
|
+
|
|
168
|
+
# Ensure the name is not longer than 96 characters (excluding username)
|
|
169
|
+
if len(repo_name) > 96:
|
|
170
|
+
repo_name = repo_name[:96].rstrip("-")
|
|
171
|
+
|
|
172
|
+
# Ensure the repo name starts with an alphanumeric character
|
|
173
|
+
repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
|
|
174
|
+
|
|
175
|
+
# If the repo_name is empty after all sanitization, use a default name
|
|
176
|
+
if not repo_name:
|
|
177
|
+
repo_name = "default-repo-name"
|
|
178
|
+
|
|
179
|
+
return f"{username}/{repo_name}"
|
|
154
180
|
|
|
155
181
|
|
|
156
182
|
def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from collections import
|
|
3
|
-
from
|
|
4
|
-
from typing import Any, TypeAlias, cast, final
|
|
2
|
+
from collections.abc import Callable, Iterable, Sequence
|
|
3
|
+
from typing import Any, TypeAlias, cast, final, overload
|
|
5
4
|
|
|
6
5
|
from lightning.pytorch import Callback, LightningModule
|
|
7
6
|
from lightning.pytorch.callbacks import LambdaCallback
|
|
@@ -19,11 +18,61 @@ class CallbackRegistrarModuleMixin:
|
|
|
19
18
|
def __init__(self, *args, **kwargs):
|
|
20
19
|
super().__init__(*args, **kwargs)
|
|
21
20
|
|
|
22
|
-
self.
|
|
21
|
+
self._nshtrainer_callbacks: list[CallbackFn] = []
|
|
22
|
+
|
|
23
|
+
@overload
|
|
24
|
+
def register_callback(
|
|
25
|
+
self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
|
|
26
|
+
): ...
|
|
27
|
+
|
|
28
|
+
@overload
|
|
29
|
+
def register_callback(
|
|
30
|
+
self,
|
|
31
|
+
/,
|
|
32
|
+
*,
|
|
33
|
+
setup: Callable | None = None,
|
|
34
|
+
teardown: Callable | None = None,
|
|
35
|
+
on_fit_start: Callable | None = None,
|
|
36
|
+
on_fit_end: Callable | None = None,
|
|
37
|
+
on_sanity_check_start: Callable | None = None,
|
|
38
|
+
on_sanity_check_end: Callable | None = None,
|
|
39
|
+
on_train_batch_start: Callable | None = None,
|
|
40
|
+
on_train_batch_end: Callable | None = None,
|
|
41
|
+
on_train_epoch_start: Callable | None = None,
|
|
42
|
+
on_train_epoch_end: Callable | None = None,
|
|
43
|
+
on_validation_epoch_start: Callable | None = None,
|
|
44
|
+
on_validation_epoch_end: Callable | None = None,
|
|
45
|
+
on_test_epoch_start: Callable | None = None,
|
|
46
|
+
on_test_epoch_end: Callable | None = None,
|
|
47
|
+
on_validation_batch_start: Callable | None = None,
|
|
48
|
+
on_validation_batch_end: Callable | None = None,
|
|
49
|
+
on_test_batch_start: Callable | None = None,
|
|
50
|
+
on_test_batch_end: Callable | None = None,
|
|
51
|
+
on_train_start: Callable | None = None,
|
|
52
|
+
on_train_end: Callable | None = None,
|
|
53
|
+
on_validation_start: Callable | None = None,
|
|
54
|
+
on_validation_end: Callable | None = None,
|
|
55
|
+
on_test_start: Callable | None = None,
|
|
56
|
+
on_test_end: Callable | None = None,
|
|
57
|
+
on_exception: Callable | None = None,
|
|
58
|
+
on_save_checkpoint: Callable | None = None,
|
|
59
|
+
on_load_checkpoint: Callable | None = None,
|
|
60
|
+
on_before_backward: Callable | None = None,
|
|
61
|
+
on_after_backward: Callable | None = None,
|
|
62
|
+
on_before_optimizer_step: Callable | None = None,
|
|
63
|
+
on_before_zero_grad: Callable | None = None,
|
|
64
|
+
on_predict_start: Callable | None = None,
|
|
65
|
+
on_predict_end: Callable | None = None,
|
|
66
|
+
on_predict_batch_start: Callable | None = None,
|
|
67
|
+
on_predict_batch_end: Callable | None = None,
|
|
68
|
+
on_predict_epoch_start: Callable | None = None,
|
|
69
|
+
on_predict_epoch_end: Callable | None = None,
|
|
70
|
+
): ...
|
|
23
71
|
|
|
24
72
|
def register_callback(
|
|
25
73
|
self,
|
|
26
74
|
callback: Callback | Iterable[Callback] | CallbackFn | None = None,
|
|
75
|
+
/,
|
|
27
76
|
*,
|
|
28
77
|
setup: Callable | None = None,
|
|
29
78
|
teardown: Callable | None = None,
|
|
@@ -109,7 +158,7 @@ class CallbackRegistrarModuleMixin:
|
|
|
109
158
|
else:
|
|
110
159
|
callback_ = callback
|
|
111
160
|
|
|
112
|
-
self.
|
|
161
|
+
self._nshtrainer_callbacks.append(callback_)
|
|
113
162
|
|
|
114
163
|
|
|
115
164
|
class CallbackModuleMixin(
|
|
@@ -136,7 +185,7 @@ class CallbackModuleMixin(
|
|
|
136
185
|
@override
|
|
137
186
|
def configure_callbacks(self):
|
|
138
187
|
callbacks = super().configure_callbacks()
|
|
139
|
-
if not isinstance(callbacks,
|
|
188
|
+
if not isinstance(callbacks, Sequence):
|
|
140
189
|
callbacks = [callbacks]
|
|
141
190
|
|
|
142
191
|
callbacks = list(callbacks)
|
|
@@ -145,7 +194,7 @@ class CallbackModuleMixin(
|
|
|
145
194
|
if callback_result is None:
|
|
146
195
|
continue
|
|
147
196
|
|
|
148
|
-
if not isinstance(callback_result,
|
|
197
|
+
if not isinstance(callback_result, Iterable):
|
|
149
198
|
callback_result = [callback_result]
|
|
150
199
|
|
|
151
200
|
for callback in callback_result:
|
|
@@ -3,7 +3,7 @@ nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjX
|
|
|
3
3
|
nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
|
|
4
4
|
nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
|
|
5
5
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
6
|
-
nshtrainer/_hf_hub.py,sha256=
|
|
6
|
+
nshtrainer/_hf_hub.py,sha256=To3BnnGWbMNNMBdzVtgrNOcNU2fi1dQpwwuclusFAbI,12169
|
|
7
7
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
8
8
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
9
9
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
@@ -58,7 +58,7 @@ nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ
|
|
|
58
58
|
nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
|
|
59
59
|
nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
|
|
60
60
|
nshtrainer/model/config.py,sha256=147uV7IukvuYE4G_ZuQNxVjnlog1BdCrAVbcj_sx9Vs,43104
|
|
61
|
-
nshtrainer/model/modules/callback.py,sha256=
|
|
61
|
+
nshtrainer/model/modules/callback.py,sha256=thhlJaqLRw2gwvb3Z6DJ8Kk8XUxKhinU_8ad30vne34,8541
|
|
62
62
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
63
63
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
64
64
|
nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
|
|
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
85
85
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
86
86
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
87
87
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
88
|
-
nshtrainer-0.19.
|
|
89
|
-
nshtrainer-0.19.
|
|
90
|
-
nshtrainer-0.19.
|
|
88
|
+
nshtrainer-0.19.1.dist-info/METADATA,sha256=NMPSdeNqcMnyB9UiQ-4f-MdhBZ_RmCAPCYcYCCvjyYI,935
|
|
89
|
+
nshtrainer-0.19.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
90
|
+
nshtrainer-0.19.1.dist-info/RECORD,,
|
|
File without changes
|