nshtrainer 0.10.9__py3-none-any.whl → 0.10.11__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -1,6 +1,8 @@
  import copy
  import datetime
  import logging
+ import shutil
+ from collections.abc import Callable
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, cast

@@ -100,3 +102,74 @@ def _write_checkpoint_metadata(
          log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
      else:
          log.info(f"Checkpoint metadata written to {checkpoint_path}")
+
+
+ def _remove_checkpoint_metadata(checkpoint_path: Path):
+     for path in (
+         checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
+         checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
+     ):
+         try:
+             path.unlink(missing_ok=True)
+         except Exception as e:
+             log.warning(f"Failed to remove {path}: {e}")
+         else:
+             log.info(f"Removed {path}")
+
+
+ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
+     # First, remove any existing metadata files
+     _remove_checkpoint_metadata(linked_checkpoint_path)
+
+     # Link the metadata files to the new checkpoint
+     for path in (
+         checkpoint_path.with_suffix(METADATA_PATH_SUFFIX),
+         checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX),
+     ):
+         linked_path = linked_checkpoint_path.with_suffix(path.suffix)
+         try:
+             try:
+                 linked_path.symlink_to(path)
+             except OSError:
+                 # on Windows, special permissions are required to create symbolic links as a regular user
+                 # fall back to copying the file
+                 shutil.copy(path, linked_path)
+         except Exception as e:
+             log.warning(f"Failed to link {path} to {linked_path}: {e}")
+         else:
+             log.info(f"Linked {path} to {linked_path}")
+
+
+ def _checkpoint_sort_key_fn(key: Callable[[CheckpointMetadata, Path], Any]):
+     def sort_key_fn(checkpoint_path: Path):
+         if not (p := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)).exists():
+             raise FileNotFoundError(f"Metadata file not found: {p}")
+
+         nonlocal key
+         return key(CheckpointMetadata.from_file(p), p)
+
+     return sort_key_fn
+
+
+ def _sort_ckpts_by_metadata(
+     checkpoint_paths: list[Path],
+     key: Callable[[CheckpointMetadata, Path], Any],
+     fallback_key: Callable[[Path], Any],
+ ):
+     # First, let's make sure all the metadata files exist.
+     # If not, use the fallback function to sort the checkpoints.
+     no_metadata_paths: list[Path] = []
+     for path in checkpoint_paths:
+         if (path.with_suffix(METADATA_PATH_SUFFIX)).exists():
+             continue
+
+         no_metadata_paths.append(path)
+
+     if no_metadata_paths:
+         log.warning(
+             f"Metadata file not found on {len(no_metadata_paths)} checkpoints: {no_metadata_paths}\n"
+             "Falling back to sorting by last modified time."
+         )
+         return sorted(checkpoint_paths, key=fallback_key)
+
+     return sorted(checkpoint_paths, key=_checkpoint_sort_key_fn(key))
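Note (not part of the diff): _sort_ckpts_by_metadata orders checkpoints by a key computed from each checkpoint's metadata sidecar and falls back to the caller-supplied fallback_key when any sidecar is missing. A minimal, hypothetical usage sketch (the directory name and glob pattern are assumptions; the helper is the one defined above):

    from pathlib import Path

    ckpt_paths = list(Path("checkpoints").glob("*.ckpt"))
    ordered = _sort_ckpts_by_metadata(
        ckpt_paths,
        # oldest first: sort by the epoch/step recorded in each metadata file
        key=lambda meta, p: (meta.epoch, meta.global_step),
        # used instead if any checkpoint lacks a metadata file
        fallback_key=lambda p: p.stat().st_mtime,
    )
    newest = ordered[-1] if ordered else None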
nshtrainer/_checkpoint/saver.py ADDED
@@ -0,0 +1,52 @@
+ import os
+ import shutil
+ from pathlib import Path
+
+ from lightning.pytorch import Trainer
+
+ from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
+
+
+ def _link_checkpoint(
+     trainer: Trainer,
+     filepath: str | Path | os.PathLike,
+     linkpath: str | Path | os.PathLike,
+     *,
+     barrier: bool,
+     metadata: bool,
+ ):
+     if not isinstance(filepath, Path):
+         filepath = Path(filepath)
+     if not isinstance(linkpath, Path):
+         linkpath = Path(linkpath)
+
+     if trainer.is_global_zero:
+         if linkpath.exists():
+             if linkpath.is_symlink() or linkpath.is_file():
+                 linkpath.unlink()
+             elif linkpath.is_dir():
+                 shutil.rmtree(linkpath)
+             _remove_checkpoint_metadata(linkpath)
+
+         try:
+             target_path = filepath.relative_to(linkpath.parent)
+             linkpath.symlink_to(target_path)
+         except OSError:
+             # on Windows, special permissions are required to create symbolic links as a regular user
+             # fall back to copying the file
+             shutil.copy(filepath, linkpath)
+
+         _link_checkpoint_metadata(filepath, linkpath)
+     if barrier:
+         trainer.strategy.barrier()
+
+
+ def _remove_checkpoint(
+     trainer: Trainer,
+     filepath: str | Path | os.PathLike,
+     remove_metadata: bool = True,
+ ):
+     if not isinstance(filepath, Path):
+         filepath = Path(filepath)
+     trainer.strategy.remove_checkpoint(filepath)
+     _remove_checkpoint_metadata(filepath)
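Aside (not part of the diff): _link_checkpoint relies on the usual symlink-then-copy fallback, since unprivileged users on Windows generally cannot create symbolic links. A self-contained sketch of that pattern with hypothetical names:

    import shutil
    from pathlib import Path

    def link_or_copy(target: Path, link: Path) -> None:
        # Replace any stale link, then prefer a relative symlink beside the target.
        link.unlink(missing_ok=True)
        try:
            link.symlink_to(target.relative_to(link.parent))
        except OSError:
            # e.g. Windows without symlink privileges: fall back to a copy
            shutil.copy(target, link)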
nshtrainer/callbacks/early_stopping.py CHANGED
@@ -1,5 +1,5 @@
+ import logging
  import math
- from logging import getLogger

  from lightning.fabric.utilities.rank_zero import _get_rank
  from lightning.pytorch import Trainer
@@ -7,7 +7,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
  from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
  from typing_extensions import override

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class EarlyStopping(_EarlyStopping):
nshtrainer/callbacks/finite_checks.py CHANGED
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Literal

  import torch
@@ -7,7 +7,7 @@ from typing_extensions import override

  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def finite_checks(
nshtrainer/callbacks/latest_epoch_checkpoint.py CHANGED
@@ -6,6 +6,8 @@ from lightning.pytorch import LightningModule, Trainer
  from lightning.pytorch.callbacks import Checkpoint
  from typing_extensions import override

+ from .._checkpoint.metadata import _sort_ckpts_by_metadata
+ from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
  from .base import CallbackConfigBase

  log = logging.getLogger(__name__)
@@ -17,15 +19,18 @@ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
      dirpath: str | Path | None = None
      """Directory path to save the checkpoint file."""

-     filename: str = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+     filename: str = "epoch{epoch:02d}_step{step:04d}"
      """Checkpoint filename. This must not include the extension."""

      save_weights_only: bool = False
      """Whether to save only the model's weights or the entire model object."""

-     latest_symlink_filename: str | None = "latest.ckpt"
+     latest_symlink_filename: str | None = "latest"
      """Filename for the latest symlink. If None, no symlink will be created."""

+     latest_k: int | Literal["all"] = 1
+     """Number of latest checkpoints to keep. If "all", all checkpoints are kept."""
+
      @override
      def create_callbacks(self, root_config):
          dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
@@ -37,38 +42,73 @@ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):


  class LatestEpochCheckpoint(Checkpoint):
+     PREFIX = "latest_"
+     EXTENSION = ".ckpt"
+
      def __init__(self, config: LatestEpochCheckpointCallbackConfig, dirpath: Path):
          super().__init__()

          self.config = config
          self.dirpath = dirpath

+     @override
+     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+         self._save_new_checkpoint(trainer)
+
+     def _latest_symlink_filename(self):
+         if (filename := self.config.latest_symlink_filename) is None:
+             return None
+         return f"{filename}{self.EXTENSION}"
+
      def _ckpt_path(self, trainer: Trainer):
-         return self.dirpath / self.config.filename.format(
+         filename = self.config.filename.format(
              epoch=trainer.current_epoch, step=trainer.global_step
          )
+         filename = f"{self.PREFIX}{filename}.{self.EXTENSION}"
+         return self.dirpath / filename
+
+     def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
+         for ckpt_path in ckpt_paths:
+             _remove_checkpoint(trainer, ckpt_path, remove_metadata=True)
+
+     def _remove_old_checkpoints(self, trainer: Trainer):
+         if (latest_k := self.config.latest_k) == "all":
+             return
+
+         # Get all configs, ignoring the latest symlink
+         ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+         # Ignore the latest symlink
+         if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
+             ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
+
+         # Sort by epoch, then step, then last modified
+         ckpt_paths = _sort_ckpts_by_metadata(
+             ckpt_paths,
+             key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
+             fallback_key=lambda p: p.stat().st_mtime,
+             # ^ Called if metadata is not found on all checkpoints
+         )
+
+         # Remove all but the latest k checkpoints
+         ckpts_to_remove = ckpt_paths[:-latest_k]
+         self._remove_checkpoints(trainer, ckpts_to_remove)
+
+     def _save_new_checkpoint(self, trainer: Trainer):
+         # Remove old checkpoints
+         self._remove_old_checkpoints(trainer)

-     @override
-     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
          # Save the new checkpoint
          filepath = self._ckpt_path(trainer)
          trainer.save_checkpoint(filepath, self.config.save_weights_only)

          # Create the latest symlink
-         if (
-             trainer.is_global_zero
-             and (symlink_filename := self.config.latest_symlink_filename) is not None
-         ):
+         if (symlink_filename := self._latest_symlink_filename()) is not None:
              symlink_path = self.dirpath / symlink_filename
-             symlink_path.unlink(missing_ok=True)
-             symlink_path.symlink_to(filepath.name)
+             _link_checkpoint(
+                 trainer,
+                 filepath,
+                 symlink_path,
+                 barrier=True,
+                 metadata=True,
+             )
              log.info(f"Created latest symlink: {symlink_path}")
-
-     def latest_checkpoint(self):
-         if (symlink_filename := self.config.latest_symlink_filename) is None:
-             return None
-
-         if not (symlink_path := self.dirpath / symlink_filename).exists():
-             return None
-
-         return symlink_path
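Aside (not part of the diff): the new latest_k field makes LatestEpochCheckpoint prune old latest_*.ckpt files (and their metadata sidecars) down to the k most recent after every epoch. A hypothetical configuration sketch using only the fields shown above:

    # Keep the three most recent epoch checkpoints plus the "latest" symlink.
    config = LatestEpochCheckpointCallbackConfig(
        filename="epoch{epoch:02d}_step{step:04d}",
        latest_symlink_filename="latest",
        latest_k=3,  # or "all" to keep every checkpoint
    )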
nshtrainer/callbacks/model_checkpoint.py CHANGED
@@ -1,21 +1,23 @@
+ import logging
  import re
  from datetime import timedelta
- from logging import getLogger
  from pathlib import Path
  from typing import TYPE_CHECKING, Literal

+ from lightning.pytorch import Trainer
  from lightning.pytorch.callbacks.model_checkpoint import (
      ModelCheckpoint as _ModelCheckpoint,
  )
  from typing_extensions import override

+ from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
  from ..metrics import MetricConfig
  from .base import CallbackConfigBase

  if TYPE_CHECKING:
      from ..model.config import BaseConfig

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _convert_string(input_string: str):
@@ -158,6 +160,8 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):


  class ModelCheckpoint(_ModelCheckpoint):
+     CHECKPOINT_NAME_LAST = "best"
+
      @override
      def __init__(
          self,
@@ -185,3 +189,17 @@ class ModelCheckpoint(_ModelCheckpoint):
              save_on_train_epoch_end=self.config.save_on_train_epoch_end,
              enable_version_counter=self.config.enable_version_counter,
          )
+
+     @override
+     def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
+         return _link_checkpoint(
+             trainer,
+             filepath,
+             linkpath,
+             barrier=True,
+             metadata=True,
+         )
+
+     @override
+     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
+         return _remove_checkpoint(trainer, filepath, remove_metadata=True)
nshtrainer/callbacks/norm_logging.py CHANGED
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Literal, cast

  import torch
@@ -9,7 +9,7 @@ from typing_extensions import override

  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def grad_norm(
nshtrainer/callbacks/throughput_monitor.py CHANGED
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable

  from typing_extensions import NotRequired, override
@@ -6,7 +6,7 @@ from typing_extensions import NotRequired, override
  from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class ThroughputMonitorBatchStats(TypedDict):
nshtrainer/callbacks/wandb_watch.py CHANGED
@@ -1,4 +1,4 @@
- from logging import getLogger
+ import logging
  from typing import Literal, Protocol, cast, runtime_checkable

  import torch.nn as nn
@@ -9,7 +9,7 @@ from typing_extensions import override

  from .base import CallbackConfigBase

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  @runtime_checkable
nshtrainer/data/balanced_batch_sampler.py CHANGED
@@ -1,6 +1,6 @@
  import heapq
+ import logging
  from functools import cached_property
- from logging import getLogger
  from typing import Any, Protocol, runtime_checkable

  import numpy as np
@@ -10,7 +10,7 @@ from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
  from torch.utils.data import BatchSampler, Dataset, DistributedSampler
  from typing_extensions import override

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):
nshtrainer/model/base.py CHANGED
@@ -1,7 +1,7 @@
  import inspect
+ import logging
  from abc import ABC, abstractmethod
  from collections.abc import MutableMapping
- from logging import getLogger
  from typing import IO, TYPE_CHECKING, Any, Generic, cast

  import torch
@@ -21,7 +21,7 @@ from .modules.profiler import ProfilerMixin
  from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
  from .modules.shared_parameters import SharedParametersModuleMixin

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)

  THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)

nshtrainer/model/config.py CHANGED
@@ -1,4 +1,5 @@
  import copy
+ import logging
  import os
  import string
  import time
@@ -6,7 +7,6 @@ import warnings
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from datetime import timedelta
- from logging import getLogger
  from pathlib import Path
  from typing import (
      Annotated,
@@ -46,7 +46,7 @@ from ..callbacks.base import CallbackConfigBase
  from ..metrics import MetricConfig
  from ._environment import EnvironmentConfig

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class IdSeedWarning(Warning):
nshtrainer/model/modules/callback.py CHANGED
@@ -1,6 +1,6 @@
+ import logging
  from collections import abc
  from collections.abc import Callable, Iterable
- from logging import getLogger
  from typing import Any, TypeAlias, cast, final

  from lightning.pytorch import Callback, LightningModule
@@ -9,7 +9,7 @@ from typing_extensions import override

  from ...util.typing_utils import mixin_base_type

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)

  CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]

nshtrainer/model/modules/debug.py CHANGED
@@ -1,9 +1,9 @@
- from logging import getLogger
+ import logging

  import torch
  import torch.distributed

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  class DebugModuleMixin:
nshtrainer/model/modules/rlp_sanity_checks.py CHANGED
@@ -1,5 +1,5 @@
+ import logging
  from collections.abc import Mapping
- from logging import getLogger
  from typing import cast

  import torch
@@ -14,7 +14,7 @@ from ...util.typing_utils import mixin_base_type
  from ..config import BaseConfig
  from .callback import CallbackModuleMixin

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
nshtrainer/model/modules/shared_parameters.py CHANGED
@@ -1,5 +1,5 @@
+ import logging
  from collections.abc import Sequence
- from logging import getLogger
  from typing import cast

  import torch.nn as nn
@@ -10,7 +10,7 @@ from ...util.typing_utils import mixin_base_type
  from ..config import BaseConfig
  from .callback import CallbackRegistrarModuleMixin

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
nshtrainer/util/environment.py CHANGED
@@ -1,8 +1,8 @@
+ import logging
  import os
  from contextlib import contextmanager
- from logging import getLogger

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  @contextmanager
nshtrainer/util/seed.py CHANGED
@@ -1,8 +1,8 @@
- from logging import getLogger
+ import logging

  import lightning.fabric.utilities.seed as LS

- log = getLogger(__name__)
+ log = logging.getLogger(__name__)


  def seed_everything(seed: int | None, *, workers: bool = False):
nshtrainer-0.10.9.dist-info/METADATA → nshtrainer-0.10.11.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.10.9
+ Version: 0.10.11
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
nshtrainer-0.10.9.dist-info/RECORD → nshtrainer-0.10.11.dist-info/RECORD
@@ -1,6 +1,7 @@
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
  nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
- nshtrainer/_checkpoint/metadata.py,sha256=C7je_soYyEbZjiq7p2_pSVFkgcXnz2J2H5sMy8oskx0,3051
+ nshtrainer/_checkpoint/metadata.py,sha256=B6kPmWsq2TQh0gTzBx-1pLIwTVEs_Qw5v0nHEeTBdO4,5636
+ nshtrainer/_checkpoint/saver.py,sha256=KZp9ITUVHwj2Ttu81zXKdlS_h-fKkHearspwuAijDpM,1501
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
@@ -9,22 +10,22 @@ nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gO
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
  nshtrainer/callbacks/actsave.py,sha256=aY6T_NAzaFAVU8WMHOXnWL5wd2bi8eVxeU2S0iAs70c,4446
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
- nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
+ nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
- nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
+ nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
  nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
- nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zCRAUsqW-2PaoIwVKlXOqdh2uF_B_YUUTmQO1wSomR8,2489
+ nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=t4vWa4PvJDO3rKXKZbuegm7iLl7xCEd17wNif0Bp-BA,4138
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
- nshtrainer/callbacks/model_checkpoint.py,sha256=N0raLsHlCVSbO3QU5eNFUXUDqxxW3C73oQwceMnFE_k,5955
- nshtrainer/callbacks/norm_logging.py,sha256=EWyrfkp8iHjQi9iAAXHxb0xStw2RwkdpKG2_gLarQRA,6281
+ nshtrainer/callbacks/model_checkpoint.py,sha256=MaDkD8Ismcj8u6l2flCFlqJR3-k1Tc4xzhxNWNux4n0,6556
+ nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
  nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKhA9qBEosiV8ZNhhZ6lI,3355
  nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
- nshtrainer/callbacks/throughput_monitor.py,sha256=4EF3b79HdHiRgBGIFDyD4O1oywb5h1tV8nml7NuuDjU,1845
+ nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
- nshtrainer/callbacks/wandb_watch.py,sha256=0JKDEPSDmiKverFm00lFfufOvtvr49akFXScvdUQnqc,2930
+ nshtrainer/callbacks/wandb_watch.py,sha256=EJ93mtJlph4BZsXh8HJPNiw2VNSm2N6TOwpCwqRAeKI,2923
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
- nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
+ nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
  nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -51,15 +52,15 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
  nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
  nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
  nshtrainer/model/_environment.py,sha256=oTtecQeF5oY2RV7UkkSLnzDy3clz4AUkf9oocD6-e54,23115
- nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
- nshtrainer/model/config.py,sha256=FRjn2pVbCzitBeIFcKXXPQ0TCmqvAwSylAvfedf1NHg,53356
- nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
- nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
+ nshtrainer/model/base.py,sha256=WtCj0-nLWeW04Tu2TWVjIq0D-jW_kMN2hg--4VWVnvE,17505
+ nshtrainer/model/config.py,sha256=orGBrp8TXnHksfAzXxNJVDdo0X_iIn_nda6BZDS9N70,53349
+ nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
+ nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
  nshtrainer/model/modules/logger.py,sha256=YYhehQysqTjuVFcd_EREYDh57CIlezidFBS2Ohp_xKo,5661
  nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
- nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
- nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
+ nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
+ nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
  nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
  nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
  nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
@@ -73,11 +74,11 @@ nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2f
  nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
  nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
  nshtrainer/trainer/trainer.py,sha256=tFyzIsF8c-FABTH6wwDOR9y8kydVJqeVO7PDNFMvhSU,16950
- nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
- nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
+ nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
+ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.10.9.dist-info/METADATA,sha256=zpMnR1Jwc9kX7BhnHT6NA3nsPUwQzMAraBPOGsUSD1w,695
- nshtrainer-0.10.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.10.9.dist-info/RECORD,,
+ nshtrainer-0.10.11.dist-info/METADATA,sha256=9WAsp25_csjDcchr5X22g7ocQpQ-d-ewB3gS9EAZSE8,696
+ nshtrainer-0.10.11.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.10.11.dist-info/RECORD,,