nshtrainer 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -10,11 +10,7 @@ from nshrunner._env import SNAPSHOT_DIR
10
10
  from typing_extensions import override
11
11
 
12
12
  from ._callback import NTCallbackBase
13
- from .callbacks.base import (
14
- CallbackConfigBase,
15
- CallbackMetadataConfig,
16
- CallbackWithMetadata,
17
- )
13
+ from .callbacks.base import CallbackConfigBase
18
14
 
19
15
  if TYPE_CHECKING:
20
16
  from huggingface_hub import HfApi # noqa: F401
@@ -81,10 +77,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
81
77
 
82
78
  @override
83
79
  def create_callbacks(self, root_config):
84
- yield CallbackWithMetadata(
85
- HFHubCallback(self),
86
- CallbackMetadataConfig(ignore_if_exists=True),
87
- )
80
+ yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
88
81
 
89
82
 
90
83
  def _api(token: str | None = None):
@@ -2,29 +2,24 @@ from abc import ABC, abstractmethod
2
2
  from collections import Counter
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass
5
- from typing import TYPE_CHECKING, TypeAlias, TypedDict
5
+ from typing import TYPE_CHECKING, ClassVar, TypeAlias
6
6
 
7
7
  import nshconfig as C
8
8
  from lightning.pytorch import Callback
9
+ from typing_extensions import TypedDict, Unpack
9
10
 
10
11
  if TYPE_CHECKING:
11
12
  from ..model.config import BaseConfig
12
13
 
13
14
 
14
- class CallbackMetadataDict(TypedDict, total=False):
15
+ class CallbackMetadataConfig(TypedDict, total=False):
15
16
  ignore_if_exists: bool
16
- """If `True`, the callback will not be added if another callback with the same class already exists."""
17
+ """If `True`, the callback will not be added if another callback with the same class already exists.
18
+ Default is `False`."""
17
19
 
18
20
  priority: int
19
- """Priority of the callback. Callbacks with higher priority will be loaded first."""
20
-
21
-
22
- class CallbackMetadataConfig(C.Config):
23
- ignore_if_exists: bool = False
24
- """If `True`, the callback will not be added if another callback with the same class already exists."""
25
-
26
- priority: int = 0
27
- """Priority of the callback. Callbacks with higher priority will be loaded first."""
21
+ """Priority of the callback. Callbacks with higher priority will be loaded first.
22
+ Default is `0`."""
28
23
 
29
24
 
30
25
  @dataclass(frozen=True)
@@ -37,13 +32,18 @@ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
37
32
 
38
33
 
39
34
  class CallbackConfigBase(C.Config, ABC):
40
- metadata: CallbackMetadataConfig = CallbackMetadataConfig()
35
+ metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig()
41
36
  """Metadata for the callback."""
42
37
 
43
- def with_metadata(self, callback: Callback, **metadata: CallbackMetadataDict):
44
- return CallbackWithMetadata(
45
- callback=callback, metadata=self.metadata.model_copy(update=metadata)
46
- )
38
+ @classmethod
39
+ def with_metadata(
40
+ cls, callback: Callback, **kwargs: Unpack[CallbackMetadataConfig]
41
+ ):
42
+ metadata: CallbackMetadataConfig = {}
43
+ metadata.update(cls.metadata)
44
+ metadata.update(kwargs)
45
+
46
+ return CallbackWithMetadata(callback=callback, metadata=metadata)
47
47
 
48
48
  @abstractmethod
49
49
  def create_callbacks(
@@ -73,7 +73,7 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
73
73
  for callback in callbacks:
74
74
  # If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
75
75
  if (
76
- callback.metadata.ignore_if_exists
76
+ callback.metadata.get("ignore_if_exists", False)
77
77
  and callback_classes[callback.callback.__class__] > 1
78
78
  ):
79
79
  continue
@@ -89,7 +89,10 @@ def _process_and_filter_callbacks(
89
89
  callbacks = list(callbacks)
90
90
 
91
91
  # Sort by priority (higher priority first)
92
- callbacks.sort(key=lambda callback: callback.metadata.priority, reverse=True)
92
+ callbacks.sort(
93
+ key=lambda callback: callback.metadata.get("priority", 0),
94
+ reverse=True,
95
+ )
93
96
 
94
97
  # Process `ignore_if_exists`
95
98
  callbacks = _filter_ignore_if_exists(callbacks)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.21.0
3
+ Version: 0.22.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -26,7 +26,6 @@ Requires-Dist: torchmetrics ; extra == "extra"
26
26
  Requires-Dist: typing-extensions
27
27
  Requires-Dist: wandb ; extra == "extra"
28
28
  Requires-Dist: wrapt ; extra == "extra"
29
- Requires-Dist: zstandard ; extra == "extra"
30
29
  Description-Content-Type: text/markdown
31
30
 
32
31
 
@@ -4,11 +4,11 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=TLAt7yR3KhSRbXCtomLMxcMvOiAju873A1ZRo8VWNwA,5179
5
5
  nshtrainer/_checkpoint/saver.py,sha256=6W-Rbc3QGuhcF_mcwN8v31uEjLQCsZvt8CPuqPs4m5g,1342
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=0bOhJNyIjQGJsMRaW7qQJc1oTnUMHj08auuztzTQvZ0,16906
7
+ nshtrainer/_hf_hub.py,sha256=iqhXH54RhSqmot_K3UCVcHTC_TC81_YY7cwvHGHXXlw,16782
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
- nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
11
+ nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
13
  nshtrainer/callbacks/checkpoint/_base.py,sha256=r6IPpl3sGUmxBNv80y9r326lTrPAIVSU3Fu-3LrYH2s,6691
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.21.0.dist-info/METADATA,sha256=7QfSX_yXi-Up6uxOVFfDPn4ieGK5b3UgQfO_KFsNzXk,979
91
- nshtrainer-0.21.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.21.0.dist-info/RECORD,,
90
+ nshtrainer-0.22.0.dist-info/METADATA,sha256=sdjt9S4X3xiIGgD6FNF06yIyC1tJA89B9Qm9mxy29tc,935
91
+ nshtrainer-0.22.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.22.0.dist-info/RECORD,,