nshtrainer 0.21.0__tar.gz → 0.22.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/PKG-INFO +1 -2
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/pyproject.toml +1 -3
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_checkpoint/metadata.py +2 -1
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_hf_hub.py +2 -9
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/base.py +22 -19
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/trainer.py +2 -1
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/README.md +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/typing_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: nshtrainer
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.22.1
|
|
4
4
|
Summary:
|
|
5
5
|
Author: Nima Shoghi
|
|
6
6
|
Author-email: nimashoghi@gmail.com
|
|
@@ -26,7 +26,6 @@ Requires-Dist: torchmetrics ; extra == "extra"
|
|
|
26
26
|
Requires-Dist: typing-extensions
|
|
27
27
|
Requires-Dist: wandb ; extra == "extra"
|
|
28
28
|
Requires-Dist: wrapt ; extra == "extra"
|
|
29
|
-
Requires-Dist: zstandard ; extra == "extra"
|
|
30
29
|
Description-Content-Type: text/markdown
|
|
31
30
|
|
|
32
31
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "nshtrainer"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.22.1"
|
|
4
4
|
description = ""
|
|
5
5
|
authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -23,7 +23,6 @@ GitPython = { version = "*", optional = true }
|
|
|
23
23
|
wandb = { version = "*", optional = true }
|
|
24
24
|
tensorboard = { version = "*", optional = true }
|
|
25
25
|
huggingface-hub = { version = "*", optional = true }
|
|
26
|
-
zstandard = { version = "*", optional = true }
|
|
27
26
|
|
|
28
27
|
[tool.poetry.group.dev.dependencies]
|
|
29
28
|
pyright = "^1.1.372"
|
|
@@ -54,5 +53,4 @@ extra = [
|
|
|
54
53
|
"wandb",
|
|
55
54
|
"tensorboard",
|
|
56
55
|
"huggingface-hub",
|
|
57
|
-
"zstandard",
|
|
58
56
|
]
|
|
@@ -11,6 +11,7 @@ import numpy as np
|
|
|
11
11
|
import torch
|
|
12
12
|
|
|
13
13
|
from ..util._environment_info import EnvironmentConfig
|
|
14
|
+
from ..util.path import get_relative_path
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
from ..model import BaseConfig, LightningModuleBase
|
|
@@ -145,7 +146,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
|
|
|
145
146
|
# We should store the path as a relative path
|
|
146
147
|
# to the metadata file to avoid issues with
|
|
147
148
|
# moving the checkpoint directory
|
|
148
|
-
linked_path.symlink_to(
|
|
149
|
+
linked_path.symlink_to(get_relative_path(linked_path, path))
|
|
149
150
|
except OSError:
|
|
150
151
|
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
151
152
|
# fall back to copying the file
|
|
@@ -10,11 +10,7 @@ from nshrunner._env import SNAPSHOT_DIR
|
|
|
10
10
|
from typing_extensions import override
|
|
11
11
|
|
|
12
12
|
from ._callback import NTCallbackBase
|
|
13
|
-
from .callbacks.base import
|
|
14
|
-
CallbackConfigBase,
|
|
15
|
-
CallbackMetadataConfig,
|
|
16
|
-
CallbackWithMetadata,
|
|
17
|
-
)
|
|
13
|
+
from .callbacks.base import CallbackConfigBase
|
|
18
14
|
|
|
19
15
|
if TYPE_CHECKING:
|
|
20
16
|
from huggingface_hub import HfApi # noqa: F401
|
|
@@ -81,10 +77,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
|
|
|
81
77
|
|
|
82
78
|
@override
|
|
83
79
|
def create_callbacks(self, root_config):
|
|
84
|
-
yield
|
|
85
|
-
HFHubCallback(self),
|
|
86
|
-
CallbackMetadataConfig(ignore_if_exists=True),
|
|
87
|
-
)
|
|
80
|
+
yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
|
|
88
81
|
|
|
89
82
|
|
|
90
83
|
def _api(token: str | None = None):
|
|
@@ -2,29 +2,24 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from collections import Counter
|
|
3
3
|
from collections.abc import Iterable
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
-
from typing import TYPE_CHECKING,
|
|
5
|
+
from typing import TYPE_CHECKING, ClassVar, TypeAlias
|
|
6
6
|
|
|
7
7
|
import nshconfig as C
|
|
8
8
|
from lightning.pytorch import Callback
|
|
9
|
+
from typing_extensions import TypedDict, Unpack
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
11
12
|
from ..model.config import BaseConfig
|
|
12
13
|
|
|
13
14
|
|
|
14
|
-
class
|
|
15
|
+
class CallbackMetadataConfig(TypedDict, total=False):
|
|
15
16
|
ignore_if_exists: bool
|
|
16
|
-
"""If `True`, the callback will not be added if another callback with the same class already exists.
|
|
17
|
+
"""If `True`, the callback will not be added if another callback with the same class already exists.
|
|
18
|
+
Default is `False`."""
|
|
17
19
|
|
|
18
20
|
priority: int
|
|
19
|
-
"""Priority of the callback. Callbacks with higher priority will be loaded first.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class CallbackMetadataConfig(C.Config):
|
|
23
|
-
ignore_if_exists: bool = False
|
|
24
|
-
"""If `True`, the callback will not be added if another callback with the same class already exists."""
|
|
25
|
-
|
|
26
|
-
priority: int = 0
|
|
27
|
-
"""Priority of the callback. Callbacks with higher priority will be loaded first."""
|
|
21
|
+
"""Priority of the callback. Callbacks with higher priority will be loaded first.
|
|
22
|
+
Default is `0`."""
|
|
28
23
|
|
|
29
24
|
|
|
30
25
|
@dataclass(frozen=True)
|
|
@@ -37,13 +32,18 @@ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
|
|
|
37
32
|
|
|
38
33
|
|
|
39
34
|
class CallbackConfigBase(C.Config, ABC):
|
|
40
|
-
metadata: CallbackMetadataConfig = CallbackMetadataConfig()
|
|
35
|
+
metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig()
|
|
41
36
|
"""Metadata for the callback."""
|
|
42
37
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
38
|
+
@classmethod
|
|
39
|
+
def with_metadata(
|
|
40
|
+
cls, callback: Callback, **kwargs: Unpack[CallbackMetadataConfig]
|
|
41
|
+
):
|
|
42
|
+
metadata: CallbackMetadataConfig = {}
|
|
43
|
+
metadata.update(cls.metadata)
|
|
44
|
+
metadata.update(kwargs)
|
|
45
|
+
|
|
46
|
+
return CallbackWithMetadata(callback=callback, metadata=metadata)
|
|
47
47
|
|
|
48
48
|
@abstractmethod
|
|
49
49
|
def create_callbacks(
|
|
@@ -73,7 +73,7 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
|
|
|
73
73
|
for callback in callbacks:
|
|
74
74
|
# If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
|
|
75
75
|
if (
|
|
76
|
-
callback.metadata.ignore_if_exists
|
|
76
|
+
callback.metadata.get("ignore_if_exists", False)
|
|
77
77
|
and callback_classes[callback.callback.__class__] > 1
|
|
78
78
|
):
|
|
79
79
|
continue
|
|
@@ -89,7 +89,10 @@ def _process_and_filter_callbacks(
|
|
|
89
89
|
callbacks = list(callbacks)
|
|
90
90
|
|
|
91
91
|
# Sort by priority (higher priority first)
|
|
92
|
-
callbacks.sort(
|
|
92
|
+
callbacks.sort(
|
|
93
|
+
key=lambda callback: callback.metadata.get("priority", 0),
|
|
94
|
+
reverse=True,
|
|
95
|
+
)
|
|
93
96
|
|
|
94
97
|
# Process `ignore_if_exists`
|
|
95
98
|
callbacks = _filter_ignore_if_exists(callbacks)
|
|
@@ -439,7 +439,8 @@ class Trainer(LightningTrainer):
|
|
|
439
439
|
):
|
|
440
440
|
# If we have a cached path, then we symlink it to the new path.
|
|
441
441
|
log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
|
|
442
|
-
|
|
442
|
+
if self.is_global_zero:
|
|
443
|
+
_link_checkpoint(cached_path, filepath, metadata=False)
|
|
443
444
|
else:
|
|
444
445
|
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
445
446
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
|
File without changes
|
{nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|