nshtrainer 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py
CHANGED
|
@@ -10,11 +10,7 @@ from nshrunner._env import SNAPSHOT_DIR
|
|
|
10
10
|
from typing_extensions import override
|
|
11
11
|
|
|
12
12
|
from ._callback import NTCallbackBase
|
|
13
|
-
from .callbacks.base import (
|
|
14
|
-
CallbackConfigBase,
|
|
15
|
-
CallbackMetadataConfig,
|
|
16
|
-
CallbackWithMetadata,
|
|
17
|
-
)
|
|
13
|
+
from .callbacks.base import CallbackConfigBase
|
|
18
14
|
|
|
19
15
|
if TYPE_CHECKING:
|
|
20
16
|
from huggingface_hub import HfApi # noqa: F401
|
|
@@ -81,10 +77,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
|
|
|
81
77
|
|
|
82
78
|
@override
|
|
83
79
|
def create_callbacks(self, root_config):
|
|
84
|
-
yield
|
|
85
|
-
HFHubCallback(self),
|
|
86
|
-
CallbackMetadataConfig(ignore_if_exists=True),
|
|
87
|
-
)
|
|
80
|
+
yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
|
|
88
81
|
|
|
89
82
|
|
|
90
83
|
def _api(token: str | None = None):
|
nshtrainer/callbacks/base.py
CHANGED
|
@@ -2,29 +2,24 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from collections import Counter
|
|
3
3
|
from collections.abc import Iterable
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
-
from typing import TYPE_CHECKING,
|
|
5
|
+
from typing import TYPE_CHECKING, ClassVar, TypeAlias
|
|
6
6
|
|
|
7
7
|
import nshconfig as C
|
|
8
8
|
from lightning.pytorch import Callback
|
|
9
|
+
from typing_extensions import TypedDict, Unpack
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
11
12
|
from ..model.config import BaseConfig
|
|
12
13
|
|
|
13
14
|
|
|
14
|
-
class
|
|
15
|
+
class CallbackMetadataConfig(TypedDict, total=False):
|
|
15
16
|
ignore_if_exists: bool
|
|
16
|
-
"""If `True`, the callback will not be added if another callback with the same class already exists.
|
|
17
|
+
"""If `True`, the callback will not be added if another callback with the same class already exists.
|
|
18
|
+
Default is `False`."""
|
|
17
19
|
|
|
18
20
|
priority: int
|
|
19
|
-
"""Priority of the callback. Callbacks with higher priority will be loaded first.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class CallbackMetadataConfig(C.Config):
|
|
23
|
-
ignore_if_exists: bool = False
|
|
24
|
-
"""If `True`, the callback will not be added if another callback with the same class already exists."""
|
|
25
|
-
|
|
26
|
-
priority: int = 0
|
|
27
|
-
"""Priority of the callback. Callbacks with higher priority will be loaded first."""
|
|
21
|
+
"""Priority of the callback. Callbacks with higher priority will be loaded first.
|
|
22
|
+
Default is `0`."""
|
|
28
23
|
|
|
29
24
|
|
|
30
25
|
@dataclass(frozen=True)
|
|
@@ -37,13 +32,18 @@ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
|
|
|
37
32
|
|
|
38
33
|
|
|
39
34
|
class CallbackConfigBase(C.Config, ABC):
|
|
40
|
-
metadata: CallbackMetadataConfig = CallbackMetadataConfig()
|
|
35
|
+
metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig()
|
|
41
36
|
"""Metadata for the callback."""
|
|
42
37
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
38
|
+
@classmethod
|
|
39
|
+
def with_metadata(
|
|
40
|
+
cls, callback: Callback, **kwargs: Unpack[CallbackMetadataConfig]
|
|
41
|
+
):
|
|
42
|
+
metadata: CallbackMetadataConfig = {}
|
|
43
|
+
metadata.update(cls.metadata)
|
|
44
|
+
metadata.update(kwargs)
|
|
45
|
+
|
|
46
|
+
return CallbackWithMetadata(callback=callback, metadata=metadata)
|
|
47
47
|
|
|
48
48
|
@abstractmethod
|
|
49
49
|
def create_callbacks(
|
|
@@ -73,7 +73,7 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
|
|
|
73
73
|
for callback in callbacks:
|
|
74
74
|
# If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
|
|
75
75
|
if (
|
|
76
|
-
callback.metadata.ignore_if_exists
|
|
76
|
+
callback.metadata.get("ignore_if_exists", False)
|
|
77
77
|
and callback_classes[callback.callback.__class__] > 1
|
|
78
78
|
):
|
|
79
79
|
continue
|
|
@@ -89,7 +89,10 @@ def _process_and_filter_callbacks(
|
|
|
89
89
|
callbacks = list(callbacks)
|
|
90
90
|
|
|
91
91
|
# Sort by priority (higher priority first)
|
|
92
|
-
callbacks.sort(
|
|
92
|
+
callbacks.sort(
|
|
93
|
+
key=lambda callback: callback.metadata.get("priority", 0),
|
|
94
|
+
reverse=True,
|
|
95
|
+
)
|
|
93
96
|
|
|
94
97
|
# Process `ignore_if_exists`
|
|
95
98
|
callbacks = _filter_ignore_if_exists(callbacks)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: nshtrainer
|
|
3
|
-
Version: 0.21.0
|
|
3
|
+
Version: 0.22.0
|
|
4
4
|
Summary:
|
|
5
5
|
Author: Nima Shoghi
|
|
6
6
|
Author-email: nimashoghi@gmail.com
|
|
@@ -26,7 +26,6 @@ Requires-Dist: torchmetrics ; extra == "extra"
|
|
|
26
26
|
Requires-Dist: typing-extensions
|
|
27
27
|
Requires-Dist: wandb ; extra == "extra"
|
|
28
28
|
Requires-Dist: wrapt ; extra == "extra"
|
|
29
|
-
Requires-Dist: zstandard ; extra == "extra"
|
|
30
29
|
Description-Content-Type: text/markdown
|
|
31
30
|
|
|
32
31
|
|
|
@@ -4,11 +4,11 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
|
|
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=TLAt7yR3KhSRbXCtomLMxcMvOiAju873A1ZRo8VWNwA,5179
|
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
|
|
6
6
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
7
|
-
nshtrainer/_hf_hub.py,sha256=
|
|
7
|
+
nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
|
|
8
8
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
9
9
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
11
|
-
nshtrainer/callbacks/base.py,sha256=
|
|
11
|
+
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
|
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
|
|
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
87
87
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
88
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
89
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.
|
|
91
|
-
nshtrainer-0.
|
|
92
|
-
nshtrainer-0.21.0.dist-info/RECORD,,
|
|
90
|
+
nshtrainer-0.22.0.dist-info/METADATA,sha256=sdjt9S4X3xiIGgD6FNF06yIyC1tJA89B9Qm9mxy29tc,935
|
|
91
|
+
nshtrainer-0.22.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
+
nshtrainer-0.22.0.dist-info/RECORD,,
|
|
File without changes
|