nshtrainer 0.11.6__py3-none-any.whl → 0.11.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/loader.py +4 -4
- nshtrainer/_checkpoint/metadata.py +37 -35
- nshtrainer/_checkpoint/saver.py +18 -28
- nshtrainer/callbacks/__init__.py +3 -0
- nshtrainer/callbacks/checkpoint/__init__.py +4 -0
- nshtrainer/callbacks/checkpoint/_base.py +175 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +29 -155
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +39 -0
- nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +42 -30
- nshtrainer/callbacks/checkpoint/model_checkpoint.py +4 -13
- nshtrainer/model/config.py +3 -1
- {nshtrainer-0.11.6.dist-info → nshtrainer-0.11.8.dist-info}/METADATA +1 -1
- {nshtrainer-0.11.6.dist-info → nshtrainer-0.11.8.dist-info}/RECORD +14 -12
- {nshtrainer-0.11.6.dist-info → nshtrainer-0.11.8.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/loader.py
CHANGED
@@ -236,7 +236,7 @@ def _load_ckpt_meta(
         error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
         match on_error:
             case "warn":
-                log.warn(error_msg)
+                log.warning(error_msg)
             case "raise":
                 raise ValueError(error_msg)
             case _:
@@ -325,13 +325,13 @@ def _resolve_checkpoint(
             ),
         ]
         if not candidates:
-            log.warn(
+            log.warning(
                 "No checkpoint candidates found for `best` checkpoint strategy."
             )
            continue
 
        if (metric := strategy.metric or root_config.primary_metric) is None:
-            log.warn(
+            log.warning(
                "No metric specified for `best` checkpoint strategy, "
                "and no primary metric is set in the configuration. "
                "Skipping strategy."
@@ -360,7 +360,7 @@ def _resolve_checkpoint(
             ),
         ]
         if not candidates:
-            log.warn(
+            log.warning(
                 "No checkpoint candidates found for `last` checkpoint strategy."
             )
             continue
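Note: the removed lines above are truncated in the source rendering; given the replacements, the most plausible originals called log.warn(...), the deprecated alias of logging.Logger.warning(). A minimal illustration (not taken from the package):

import logging

log = logging.getLogger(__name__)
log.warning("preferred spelling")  # Logger.warn(...) is a deprecated alias of warning()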
nshtrainer/_checkpoint/metadata.py
CHANGED
@@ -4,7 +4,7 @@ import logging
 import shutil
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast
 
 import nshconfig as C
 import numpy as np
@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)
 
 
 METADATA_PATH_SUFFIX = ".metadata.json"
-HPARAMS_PATH_SUFFIX = ".hparams.json"
 
 
 class CheckpointMetadata(C.Config):
+    PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
     checkpoint_path: Path
     checkpoint_filename: str
 
@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
     metrics: dict[str, Any]
     environment: EnvironmentConfig
 
+    hparams: dict[str, Any] | None
+
     @classmethod
     def from_file(cls, path: Path):
         return cls.model_validate_json(path.read_text())
@@ -55,7 +58,10 @@ class CheckpointMetadata(C.Config):
 
 
 def _generate_checkpoint_metadata(
-    config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
+    config: "BaseConfig",
+    trainer: "Trainer",
+    checkpoint_path: Path,
+    metadata_path: Path,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
         metrics[name] = metric
 
     return CheckpointMetadata(
-        checkpoint_path=checkpoint_path,
+        # checkpoint_path=checkpoint_path,
+        # We should store the path as a relative path
+        # to the metadata file to avoid issues with
+        # moving the checkpoint directory
+        checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
         run_id=config.id,
         name=config.run_name,
@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
         training_time=training_time,
         metrics=metrics,
         environment=config.environment,
+        hparams=config.model_dump(mode="json"),
     )
 
 
@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-    metadata = _generate_checkpoint_metadata(config, trainer, checkpoint_path)
+    metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    metadata = _generate_checkpoint_metadata(
+        config, trainer, checkpoint_path, metadata_path
+    )
 
     # Write the metadata to the checkpoint directory
     try:
-        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
         metadata_path.write_text(metadata.model_dump_json(indent=4))
     except Exception as e:
         log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
     else:
         log.debug(f"Checkpoint metadata written to {checkpoint_path}")
 
-    # Write the hparams to the checkpoint directory
+
+def _remove_checkpoint_metadata(checkpoint_path: Path):
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
     try:
-        hparams_path = checkpoint_path.with_suffix(HPARAMS_PATH_SUFFIX)
-        hparams_path.write_text(config.model_dump_json(indent=4))
+        path.unlink(missing_ok=True)
     except Exception as e:
-        log.warning(f"Failed to write hparams to {checkpoint_path}: {e}")
+        log.warning(f"Failed to remove {path}: {e}")
    else:
-        log.debug(f"Checkpoint hparams written to {checkpoint_path}")
-
-
-def _remove_checkpoint_metadata(checkpoint_path: Path):
-    for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-        path = checkpoint_path.with_suffix(suffix)
-        try:
-            path.unlink(missing_ok=True)
-        except Exception as e:
-            log.warning(f"Failed to remove {path}: {e}")
-        else:
-            log.debug(f"Removed {path}")
+        log.debug(f"Removed {path}")
 
 
 def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
-    for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-        path = checkpoint_path.with_suffix(suffix)
-        linked_path = linked_checkpoint_path.with_suffix(suffix)
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    try:
         try:
-            try:
-                linked_path.symlink_to(path)
-            except OSError:
-                # on Windows, special permissions are required to create symbolic links as a regular user
-                # fall back to copying the file
-                shutil.copy(path, linked_path)
-        except Exception as e:
-            log.warning(f"Failed to link {path} to {linked_path}: {e}")
-        else:
-            log.debug(f"Linked {path} to {linked_path}")
+            linked_path.symlink_to(path)
+        except OSError:
+            # on Windows, special permissions are required to create symbolic links as a regular user
+            # fall back to copying the file
+            shutil.copy(path, linked_path)
+    except Exception as e:
+        log.warning(f"Failed to link {path} to {linked_path}: {e}")
+    else:
+        log.debug(f"Linked {path} to {linked_path}")
 
 
 def _sort_ckpts_by_metadata(
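Since checkpoint_path is now stored relative to the metadata file, a consumer has to resolve it against the metadata file's own directory. A minimal sketch, assuming a hypothetical on-disk layout:

from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata

# Hypothetical metadata file written next to its checkpoint:
metadata_path = Path("checkpoints/last/epoch005-step0001000.metadata.json")
meta = CheckpointMetadata.from_file(metadata_path)

# meta.checkpoint_path is relative to the metadata file's directory,
# so the pair survives moving the whole checkpoint directory:
ckpt_path = metadata_path.parent / meta.checkpoint_path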
nshtrainer/_checkpoint/saver.py
CHANGED
@@ -8,11 +8,9 @@ from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
 
 
 def _link_checkpoint(
-    trainer: Trainer,
     filepath: str | Path | os.PathLike,
     linkpath: str | Path | os.PathLike,
     *,
-    barrier: bool,
     metadata: bool,
 ):
     if not isinstance(filepath, Path):
@@ -20,26 +18,23 @@ def _link_checkpoint(
     if not isinstance(linkpath, Path):
         linkpath = Path(linkpath)
 
-    if trainer.is_global_zero:
-        if linkpath.exists():
-            if linkpath.is_symlink() or linkpath.is_file():
-                linkpath.unlink()
-            elif linkpath.is_dir():
-                shutil.rmtree(linkpath)
-        _remove_checkpoint_metadata(linkpath)
+    if linkpath.exists():
+        if linkpath.is_symlink() or linkpath.is_file():
+            linkpath.unlink()
+        elif linkpath.is_dir():
+            shutil.rmtree(linkpath)
+    _remove_checkpoint_metadata(linkpath)
 
-        try:
-            target_path = filepath.relative_to(linkpath.parent)
-            linkpath.symlink_to(target_path)
-        except OSError:
-            # on Windows, special permissions are required to create symbolic links as a regular user
-            # fall back to copying the file
-            shutil.copy(filepath, linkpath)
+    try:
+        target_path = filepath.relative_to(linkpath.parent)
+        linkpath.symlink_to(target_path)
+    except OSError:
+        # on Windows, special permissions are required to create symbolic links as a regular user
+        # fall back to copying the file
+        shutil.copy(filepath, linkpath)
 
-        if metadata:
-            _link_checkpoint_metadata(filepath, linkpath)
-    if barrier:
-        trainer.strategy.barrier()
+    if metadata:
+        _link_checkpoint_metadata(filepath, linkpath)
@@ -47,15 +42,10 @@ def _remove_checkpoint(
     filepath: str | Path | os.PathLike,
     *,
     metadata: bool,
-    barrier: bool,
 ):
     if not isinstance(filepath, Path):
         filepath = Path(filepath)
 
-    if trainer.is_global_zero:
-        trainer.strategy.remove_checkpoint(filepath)
-        if metadata:
-            _remove_checkpoint_metadata(filepath)
-
-    if barrier:
-        trainer.strategy.barrier()
+    trainer.strategy.remove_checkpoint(filepath)
+    if metadata:
+        _remove_checkpoint_metadata(filepath)
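After this change _link_checkpoint is a plain rank-local file operation; deciding which rank runs it and when to synchronize is left to the caller (see the ModelCheckpoint override further down). A usage sketch with hypothetical paths:

from nshtrainer._checkpoint.saver import _link_checkpoint

# The link target is computed relative to the link's parent directory, and
# symlink creation falls back to a plain copy where symlinks need elevated
# permissions (e.g. regular users on Windows).
_link_checkpoint(
    "checkpoints/last/epoch005-step0001000.ckpt",
    "checkpoints/last.ckpt",
    metadata=True,  # also link the .metadata.json sidecar
)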
nshtrainer/callbacks/__init__.py
CHANGED
@@ -6,6 +6,8 @@ from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
 from .checkpoint import BestCheckpoint as BestCheckpoint
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+from .checkpoint import LastCheckpoint as LastCheckpoint
+from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
 from .checkpoint import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
@@ -46,6 +48,7 @@ CallbackConfig = Annotated[
     | GradientSkippingConfig
     | EMAConfig
     | BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
nshtrainer/callbacks/checkpoint/__init__.py
CHANGED
@@ -2,6 +2,10 @@ from .best_checkpoint import BestCheckpoint as BestCheckpoint
 from .best_checkpoint import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
+from .last_checkpoint import LastCheckpoint as LastCheckpoint
+from .last_checkpoint import (
+    LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
+)
 from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
 from .latest_epoch_checkpoint import (
     LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
nshtrainer/callbacks/checkpoint/_base.py
ADDED
@@ -0,0 +1,175 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import TypeVar, override
+
+from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        root_config: "BaseConfig",
+        dirpath: Path,
+    ) -> "CheckpointBase": ...
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = Path(
+            self.dirpath
+            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+        )
+
+        yield self.create_checkpoint(root_config, dirpath)
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.symlink_dirpath = dirpath
+
+        self._last_global_step_saved = 0
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+        filename = filename.format(**current_metrics)
+        return self.dirpath / f"{filename}{self.extension()}"
+
+    def remove_old_checkpoints(self, trainer: Trainer):
+        if (topk := self.config.topk) == "all":
+            return
+
+        # Get all the checkpoint metadata
+        metas = [
+            CheckpointMetadata.from_file(p)
+            for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        ]
+
+        # Sort by the topk sort key
+        metas = sorted(metas, key=self.topk_sort_key)
+
+        # Now, the metas are sorted from the best to the worst,
+        # so we can remove the worst checkpoints
+        for meta in metas[topk:]:
+            if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+                log.warning(
+                    f"Checkpoint file not found: {old_ckpt_path}\n"
+                    "Skipping removal of the checkpoint metadata."
+                )
+                continue
+
+            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+            log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        # Save the new checkpoint
+        filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self.remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self.symlink_path()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Barrier to ensure all processes have saved the checkpoint,
+        # deleted the old checkpoints, and created the symlink before continuing
+        trainer.strategy.barrier()
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
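For orientation, a minimal sketch of a concrete callback built on the new base class. EveryEpochCheckpoint is hypothetical (not part of the package), and _base is an internal module, so treat this purely as an illustration of the three abstract members plus an event hook:

from typing import Literal

from lightning.pytorch import LightningModule, Trainer
from typing_extensions import final, override

from nshtrainer._checkpoint.metadata import CheckpointMetadata
from nshtrainer.callbacks.checkpoint._base import (
    BaseCheckpointCallbackConfig,
    CheckpointBase,
)


@final
class EveryEpochCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
    # Discriminator value; hypothetical, mirrors the shipped configs.
    name: Literal["every_epoch_checkpoint"] = "every_epoch_checkpoint"

    @override
    def create_checkpoint(self, root_config, dirpath):
        return EveryEpochCheckpoint(self, dirpath)


@final
class EveryEpochCheckpoint(CheckpointBase[EveryEpochCheckpointCallbackConfig]):
    @override
    def name(self):
        return "every_epoch"

    @override
    def default_filename(self):
        # Expanded with current_metrics(), which always provides epoch/step.
        return "epoch{epoch:03d}"

    @override
    def topk_sort_key(self, metadata: CheckpointMetadata):
        # Mirrors LastCheckpoint: order checkpoints by their save timestamp.
        return metadata.checkpoint_timestamp

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        self.save_checkpoints(trainer)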
nshtrainer/callbacks/checkpoint/best_checkpoint.py
CHANGED
@@ -1,47 +1,27 @@
 import logging
 from pathlib import Path
-from typing import Any, Literal
+from typing import Literal
 
 from lightning.pytorch import LightningModule, Trainer
-from typing_extensions import override
-
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
 
-from ..._checkpoint.metadata import _sort_ckpts_by_metadata
-from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ...metrics._config import MetricConfig
-from ..base import CallbackConfigBase
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
 
 log = logging.getLogger(__name__)
 
 
-class BestCheckpointCallbackConfig(CallbackConfigBase):
+@final
+class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
     name: Literal["best_checkpoint"] = "best_checkpoint"
 
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str = "epoch{epoch:02d}_step{step:04d}"
-    """Checkpoint filename. This must not include the extension."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
     metric: MetricConfig | None = None
     """Metric to monitor, or `None` to use the default metric."""
 
-    best_symlink_filename: str | None = "best"
-    """Filename for the best symlink. If None, no symlink will be created."""
-
-    save_top_k: int | Literal["all"] = 1
-    """The number of best checkpoints to keep."""
-
     @override
-    def create_callbacks(self, root_config):
-        dirpath = Path(
-            self.dirpath
-            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
-        )
-
+    def create_checkpoint(self, root_config, dirpath):
         # Resolve metric
         if (metric := self.metric) is None and (
             metric := root_config.primary_metric
@@ -50,147 +30,41 @@ class BestCheckpointCallbackConfig(CallbackConfigBase):
                 "No metric provided and no primary metric found in the root config"
             )
 
-        yield BestCheckpoint(self, metric, dirpath)
-
-    @property
-    def _save_top_k_value(self):
-        return float("inf" if self.save_top_k == "all" else self.save_top_k)
+        return BestCheckpoint(self, dirpath, metric)
 
 
-class BestCheckpoint(Checkpoint):
-    PREFIX = "best_"
-    EXTENSION = ".ckpt"
+@final
+class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+    @property
+    def _metric_name_normalized(self):
+        return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
 
+    @override
     def __init__(
         self,
         config: BestCheckpointCallbackConfig,
-        metric: MetricConfig,
         dirpath: Path,
+        metric: MetricConfig,
     ):
-        super().__init__()
-        self.config = config
+        super().__init__(config, dirpath)
        self.metric = metric
-        self.dirpath = dirpath
-
-        self._last_global_step_saved = 0  # no need to save when no steps were taken
 
     @override
-    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
-        self._save_best_checkpoint(trainer)
+    def name(self):
+        return f"best_{self._metric_name_normalized}"
 
-    def _best_symlink_filename(self):
-        if (filename := self.config.best_symlink_filename) is None:
-            return None
-        return f"{filename}{self.EXTENSION}"
-
-    def _ckpt_path(self, trainer: Trainer):
-        filename = self.config.filename.format(
-            epoch=trainer.current_epoch, step=trainer.global_step
-        )
-        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
-        return self.dirpath / filename
-
-    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
-        for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
+    @override
+    def default_filename(self):
+        return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
 
-    def _get_metric_value(self, metrics: dict[str, Any]):
-        return metrics.get(
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.metrics.get(
             self.metric.validation_monitor,
             float("-inf" if self.metric.mode == "max" else "inf"),
         )
 
-    def _sorted_ckpts(self):
-        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
-        return _sort_ckpts_by_metadata(
-            ckpts,
-            key=lambda meta, _: self._get_metric_value(meta.metrics),
-            reverse=(self.metric.mode == "min"),
-        )
-
-    def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
-        # Resolve the symlink filename
-        if (symlink_filename := self._best_symlink_filename()) is None:
-            return
-
-        # If the symlink already exists and points to the best checkpoint,
-        # then we don't need to create a new symlink.
-        symlink_path = self.dirpath / symlink_filename
-        if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
-            return
-
-        _link_checkpoint(
-            trainer,
-            best_ckpt_path,
-            symlink_path,
-            metadata=True,
-            barrier=False,
-        )
-        log.debug(f"Created best symlink: {symlink_path}")
-
-    def _save_best_checkpoint(self, trainer: Trainer):
-        # Skip saving the checkpoint if we're not in the fitting state
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        # Get the current metric value
-        if (current := self._get_metric_value(trainer.callback_metrics)) is None:
-            log.warning(
-                f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
-            )
-            return
-
-        # Get sorted checkpoints
-        sorted_ckpts = self._sorted_ckpts()
-
-        # If the current model is worse than the worst checkpoint,
-        # and we have already saved the maximum number of checkpoints,
-        # then don't save the current model.
-        if len(
-            sorted_ckpts
-        ) >= self.config._save_top_k_value and not self.metric.is_better(
-            current,
-            self._get_metric_value(sorted_ckpts[-1][0].metrics),
-        ):
-            return
-
-        # Save the current model
-        filepath = self._ckpt_path(trainer)
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-        log.debug(f"Saved best checkpoint: {filepath}")
-
-        # Remove worst checkpoint if we've reached save_top_k
-        # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
-        if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
-            # Get the sorted checkpoints again because now we have added a new checkpoint.
-            # We could optimize this by adding the new checkpoint to the sorted list,
-            # and then sorting it in place, but this is simpler.
-            sorted_ckpts = self._sorted_ckpts()
-            self._remove_checkpoints(
-                trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
-            )
-
-        # Create symlink to best model
-        if sorted_ckpts:
-            _, best_ckpt_path = sorted_ckpts[0]
-            self._create_symlink(trainer, best_ckpt_path)
-
-        # Update the last global step saved
-        self._last_global_step_saved = trainer.global_step
-
-        # Barrier to ensure all processes have saved the checkpoint before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self._last_global_step_saved
-            == trainer.global_step  # already saved at the last step
-        )
+    # Events
+    @override
+    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
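A hedged construction example; keyword initialization is assumed to work as for the other nshconfig-based configs, and topk comes from BaseCheckpointCallbackConfig above:

from nshtrainer.callbacks import BestCheckpointCallbackConfig

# Keep the three best checkpoints; with metric=None the callback falls back
# to root_config.primary_metric, as in create_checkpoint() above.
config = BestCheckpointCallbackConfig(topk=3)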
nshtrainer/callbacks/checkpoint/last_checkpoint.py
ADDED
@@ -0,0 +1,39 @@
+import logging
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+log = logging.getLogger(__name__)
+
+
+@final
+class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["last_checkpoint"] = "last_checkpoint"
+
+    @override
+    def create_checkpoint(self, root_config, dirpath):
+        return LastCheckpoint(self, dirpath)
+
+
+@final
+class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+    @override
+    def name(self):
+        return "last"
+
+    @override
+    def default_filename(self):
+        return "epoch{epoch:03d}-step{step:07d}"
+
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.checkpoint_timestamp
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
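The default filename is expanded with str.format over current_metrics(), which always supplies epoch and step; for example (illustrative values):

"epoch{epoch:03d}-step{step:07d}".format(epoch=5, step=1000)
# -> 'epoch005-step0001000'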
nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py
CHANGED
@@ -19,7 +19,7 @@ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
     dirpath: str | Path | None = None
     """Directory path to save the checkpoint file."""
 
-    filename: str = "epoch{epoch:02d}_step{step:04d}"
+    filename: str = "epoch{epoch:03d}_step{step:07d}"
     """Checkpoint filename. This must not include the extension."""
 
     save_weights_only: bool = False
@@ -51,6 +51,8 @@ class LatestEpochCheckpoint(Checkpoint):
         self.config = config
         self.dirpath = dirpath
 
+        self._last_global_step_saved = 0
+
     @override
     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self._save_new_checkpoint(trainer)
@@ -67,53 +69,63 @@ class LatestEpochCheckpoint(Checkpoint):
         filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
         return self.dirpath / filename
 
-    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
-        for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
-
     def _remove_old_checkpoints(self, trainer: Trainer):
         if (latest_k := self.config.latest_k) == "all":
             return
 
-        # NOTE: We add 1 to the latest_k here because
-        # we're about to save a new checkpoint.
-        latest_k += 1
-
         # Get all configs, ignoring the latest symlink
-        ckpts = self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}")
+        ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
         # Ignore the latest symlink
         if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
-            ckpts = (p for p in ckpts if p.name != latest_symlink_filename)
+            ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
 
         # Sort by epoch, then step, then last modified
-        ckpts = _sort_ckpts_by_metadata(
-            list(ckpts),
+        ckpts = _sort_ckpts_by_metadata(
+            ckpts,
             key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
             reverse=True,
         )
 
         # Remove all but the latest k checkpoints
-        ckpt_paths = [p for _, p in ckpts[latest_k:]]
-        self._remove_checkpoints(trainer, ckpt_paths)
+        # NOTE: We add 1 to the latest_k here because
+        # we're about to save a new checkpoint.
+        for _, ckpt_path in ckpts[latest_k:]:
+            _remove_checkpoint(trainer, ckpt_path, metadata=True)
 
     def _save_new_checkpoint(self, trainer: Trainer):
-        # Remove the old checkpoints
-        if trainer.is_global_zero:
-            self._remove_old_checkpoints(trainer)
-        trainer.strategy.barrier()
+        if self._should_skip_saving_checkpoint(trainer):
+            return
 
         # Save the new checkpoint
         filepath = self._ckpt_path(trainer)
         trainer.save_checkpoint(filepath, self.config.save_weights_only)
 
-        # Create the latest symlink
-        if (symlink_filename := self._latest_symlink_filename()) is not None:
-            symlink_path = self.dirpath / symlink_filename
-            _link_checkpoint(
-                trainer,
-                filepath,
-                symlink_path,
-                metadata=True,
-                barrier=True,
-            )
-            log.debug(f"Created latest symlink: {symlink_path}")
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self._remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self._latest_symlink_filename()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+        # Barrier to ensure all processes have saved the checkpoint before continuing
+        trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
nshtrainer/callbacks/checkpoint/model_checkpoint.py
CHANGED
@@ -198,19 +198,10 @@ class ModelCheckpoint(_ModelCheckpoint):
 
     @override
     def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str):  # pyright: ignore[reportIncompatibleMethodOverride]
-        _link_checkpoint(
-            trainer,
-            filepath,
-            linkpath,
-            barrier=True,
-            metadata=True,
-        )
+        if trainer.is_global_zero:
+            _link_checkpoint(filepath, linkpath, metadata=True)
+        trainer.strategy.barrier()
 
     @override
     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        _ckpt_saver_remove_checkpoint(
-            trainer,
-            filepath,
-            metadata=True,
-            barrier=False,
-        )
+        _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
nshtrainer/model/config.py
CHANGED
@@ -39,6 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    LastCheckpointCallbackConfig,
     LatestEpochCheckpointCallbackConfig,
     ModelCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
@@ -773,6 +774,7 @@ class ReproducibilityConfig(C.Config):
 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
@@ -786,7 +788,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
         # ModelCheckpointCallbackConfig(),
         BestCheckpointCallbackConfig(),
-        LatestEpochCheckpointCallbackConfig(),
+        LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
     """Checkpoint callback configurations."""
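A sketch of overriding the new default wiring; class and field names are taken from the hunks above, while the keyword-argument style is an assumption:

from nshtrainer.callbacks import (
    BestCheckpointCallbackConfig,
    LastCheckpointCallbackConfig,
)
from nshtrainer.model.config import CheckpointSavingConfig

# Mirror the new defaults, but keep the last two checkpoints instead of one.
saving = CheckpointSavingConfig(
    checkpoint_callbacks=[
        BestCheckpointCallbackConfig(),
        LastCheckpointCallbackConfig(topk=2),
    ]
)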
{nshtrainer-0.11.6.dist-info → nshtrainer-0.11.8.dist-info}/RECORD
CHANGED
@@ -1,19 +1,21 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
-nshtrainer/_checkpoint/loader.py,sha256=
-nshtrainer/_checkpoint/metadata.py,sha256=
-nshtrainer/_checkpoint/saver.py,sha256=
+nshtrainer/_checkpoint/loader.py,sha256=DSaNR8194kWon4O1svslNsCcN_8vlyLbF0LNCPfUpzI,13789
+nshtrainer/_checkpoint/metadata.py,sha256=n2PMGdA3Jn3BuoyDXVCF9dUnNjuuru5CL9jqMD-X4Vk,4918
+nshtrainer/_checkpoint/saver.py,sha256=TuSAP39DOOVvSnSukQ9RitMV60JnDg6L27fMRc2uVJc,1358
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=0ZYo2pX_MfTlPeoixOIHU2RgX4NcpxeQL4YrU2FE63Q,2665
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
-nshtrainer/callbacks/checkpoint/__init__.py,sha256=
-nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
-nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=
-nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=
+nshtrainer/callbacks/checkpoint/__init__.py,sha256=2COllQ7Rfe9IWwdLjgG0Bxc-lxVJJ_aU7A9zFObTwwY,899
+nshtrainer/callbacks/checkpoint/_base.py,sha256=9HQSa-toOyjtuDldQ71gaDVBRdryAaB_nRv5Y554tIk,5938
+nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=O6d4SqIYsxpnRj_IYX8A9VLgOBwxTdz-j2FV_nn3BT8,2067
+nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CM8f37dwaYHkjQFfJNTZTzSoF45zEjFRm-Fg1CzYmP4,1037
+nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py,sha256=II-WZHAk7leK1Vgjza0PVifrF3QetR9Nn3n1qhqtuVo,4819
+nshtrainer/callbacks/checkpoint/model_checkpoint.py,sha256=JS1z2YuEiQxk61HgZU1jySzF_pzdfXYO54_qHo-q3CQ,6776
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=s8tOHrnb_uVqLVeV2K38ZszXrXPTEGdDVfXuXgo_KDQ,3277
 nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
@@ -54,7 +56,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
 nshtrainer/model/__init__.py,sha256=NpvyQHmGaHB8xdraHmm8l7kDHLmvJSgBNQKkfYqtgyI,1454
 nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
-nshtrainer/model/config.py,sha256=
+nshtrainer/model/config.py,sha256=u6pUbgR_qqo_xP1lclhavKG1KbsQ9run_P0Am2XcsQw,54947
 nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -82,6 +84,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.11.6.dist-info/METADATA,sha256=
-nshtrainer-0.11.6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.11.6.dist-info/RECORD,,
+nshtrainer-0.11.8.dist-info/METADATA,sha256=l5Y-rA6Kxo50NbQQJ3kKp5le6Y579NY6Jmjruomfbfs,860
+nshtrainer-0.11.8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.11.8.dist-info/RECORD,,
{nshtrainer-0.11.6.dist-info → nshtrainer-0.11.8.dist-info}/WHEEL
File without changes