nshtrainer 1.3.5__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +14 -0
- nshtrainer/_checkpoint/metadata.py +4 -1
- nshtrainer/_hf_hub.py +3 -0
- nshtrainer/callbacks/checkpoint/_base.py +173 -40
- nshtrainer/callbacks/lr_monitor.py +9 -1
- nshtrainer/configs/__init__.py +1 -5
- nshtrainer/configs/trainer/__init__.py +4 -2
- nshtrainer/configs/trainer/_config/__init__.py +4 -2
- nshtrainer/trainer/_config.py +525 -73
- nshtrainer/trainer/trainer.py +11 -2
- {nshtrainer-1.3.5.dist-info → nshtrainer-1.4.0.dist-info}/METADATA +1 -1
- {nshtrainer-1.3.5.dist-info → nshtrainer-1.4.0.dist-info}/RECORD +13 -15
- nshtrainer/_directory.py +0 -72
- nshtrainer/configs/_directory/__init__.py +0 -15
- {nshtrainer-1.3.5.dist-info → nshtrainer-1.4.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -19,3 +19,17 @@ try:
     from . import configs as configs
 except BaseException:
     pass
+
+try:
+    from importlib.metadata import PackageNotFoundError, version
+except ImportError:
+    # For Python <3.8
+    from importlib_metadata import (  # pyright: ignore[reportMissingImports]
+        PackageNotFoundError,
+        version,
+    )
+
+try:
+    __version__ = version(__name__)
+except PackageNotFoundError:
+    __version__ = "unknown"
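The new tail of __init__.py exposes the installed distribution's version at import time. A minimal usage sketch (not part of the diff):

    import nshtrainer

    print(nshtrainer.__version__)  # e.g. "1.4.0"; "unknown" when package metadata is unavailable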
nshtrainer/_checkpoint/metadata.py
CHANGED
@@ -85,6 +85,7 @@ def _generate_checkpoint_metadata(
     trainer: Trainer,
     checkpoint_path: Path,
     metadata_path: Path,
+    compute_checksum: bool = True,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -105,7 +106,9 @@ def _generate_checkpoint_metadata(
         # moving the checkpoint directory
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
-        checkpoint_checksum=compute_file_checksum(checkpoint_path)
+        checkpoint_checksum=compute_file_checksum(checkpoint_path)
+        if compute_checksum
+        else "",
         run_id=trainer.hparams.id,
         name=trainer.hparams.full_name,
         project=trainer.hparams.project,
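The new compute_checksum flag lets callers skip hashing when building metadata for a checkpoint file that has not been written yet (see the top-k logic in callbacks/checkpoint/_base.py below). The diff does not show compute_file_checksum itself; a rough stdlib-only sketch of what such a helper typically does, for orientation only:

    import hashlib
    from pathlib import Path

    def sha256_checksum(path: Path) -> str:
        # Stream the file in chunks so large checkpoints need not fit in memory.
        digest = hashlib.sha256()
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()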
nshtrainer/_hf_hub.py
CHANGED
@@ -91,6 +91,9 @@ class HuggingFaceHubConfig(CallbackConfigBase):

     @override
     def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
         # Attempt to login. If it fails, we'll log a warning or error based on the configuration.
         try:
             api = _api(self.token)
nshtrainer/callbacks/checkpoint/_base.py
CHANGED
@@ -1,17 +1,19 @@
 from __future__ import annotations

 import logging
+import string
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

 import numpy as np
 import torch
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import
+from typing_extensions import override

-from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
 from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
 from ..base import CallbackConfigBase
@@ -22,6 +24,81 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


+class _FormatDict(dict):
+    """A dictionary that returns an empty string for missing keys when formatting."""
+
+    def __missing__(self, key):
+        log.debug(
+            f"Missing format key '{key}' in checkpoint filename, using empty string"
+        )
+        return ""
+
+
+def _get_checkpoint_metadata(dirpath: Path) -> list[CheckpointMetadata]:
+    """Get all checkpoint metadata from a directory."""
+    return [
+        CheckpointMetadata.from_file(p)
+        for p in dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        if p.is_file() and not p.is_symlink()
+    ]
+
+
+def _sort_checkpoint_metadata(
+    metas: list[CheckpointMetadata],
+    key_fn: Callable[[CheckpointMetadata], Any],
+    reverse: bool = False,
+) -> list[CheckpointMetadata]:
+    """Sort checkpoint metadata by the given key function."""
+    return sorted(metas, key=key_fn, reverse=reverse)
+
+
+def _remove_checkpoints(
+    trainer: Trainer,
+    dirpath: Path,
+    metas_to_remove: list[CheckpointMetadata],
+) -> None:
+    """Remove checkpoint files and their metadata."""
+    for meta in metas_to_remove:
+        ckpt_path = dirpath / meta.checkpoint_filename
+        if not ckpt_path.exists():
+            log.warning(
+                f"Checkpoint file not found: {ckpt_path}\n"
+                "Skipping removal of the checkpoint metadata."
+            )
+            continue
+
+        remove_checkpoint(trainer, ckpt_path, metadata=True)
+        log.debug(f"Removed checkpoint: {ckpt_path}")
+
+
+def _update_symlink(
+    dirpath: Path,
+    symlink_path: Path | None,
+    sort_key_fn: Callable[[CheckpointMetadata], Any],
+    sort_reverse: bool,
+) -> None:
+    """Update symlink to point to the best checkpoint."""
+    if symlink_path is None:
+        return
+
+    # Get all checkpoint metadata after any removals
+    remaining_metas = _get_checkpoint_metadata(dirpath)
+
+    if remaining_metas:
+        # Sort by the key function
+        remaining_metas = _sort_checkpoint_metadata(
+            remaining_metas, sort_key_fn, sort_reverse
+        )
+
+        # Link to the best checkpoint
+        best_meta = remaining_metas[0]
+        best_filepath = dirpath / best_meta.checkpoint_filename
+        link_checkpoint(best_filepath, symlink_path, metadata=True)
+        log.debug(f"Updated symlink {symlink_path.name} -> {best_filepath.name}")
+    else:
+        log.warning(f"No checkpoints found in {dirpath} to create symlink.")
+
+
 class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
     dirpath: str | Path | None = None
     """Directory path to save the checkpoint file."""
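_FormatDict backs the tolerant filename formatting used later in this file: unknown placeholders resolve to an empty string instead of raising KeyError. A self-contained sketch of the same stdlib mechanism (class and field names here are illustrative, not from nshtrainer):

    import string

    class SafeDict(dict):
        def __missing__(self, key):
            return ""  # unknown placeholders become empty strings

    template = "epoch{epoch}-step{global_step}"
    # Only `epoch` is available; `global_step` silently falls back to "".
    print(string.Formatter().vformat(template, (), SafeDict(epoch=3)))  # epoch3-step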
@@ -95,35 +172,27 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
     def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
         if (filename := self.config.filename) is None:
             filename = self.default_filename()
-        filename = filename.format(**current_metrics)
-        return self.dirpath / f"{filename}{self.extension()}"
-
-    def remove_old_checkpoints(self, trainer: Trainer):
-        if (topk := self.config.topk) == "all":
-            return

-        #
-            for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
-            if p.is_file() and not p.is_symlink()
+        # Extract all field names from the format string
+        field_names = [
+            fname for _, fname, _, _ in string.Formatter().parse(filename) if fname
         ]

-        #
+        # Filter current_metrics to only include keys that are in the format string
+        format_dict = {k: v for k, v in current_metrics.items() if k in field_names}

+        try:
+            formatted_filename = filename.format(**format_dict)
+        except KeyError as e:
+            log.warning(
+                f"Missing key {e} in {filename=} with {format_dict=}. Using default values."
+            )
+            # Provide a simple fallback for missing keys
+            formatted_filename = string.Formatter().vformat(
+                filename, (), _FormatDict(format_dict)
+            )

-        log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+        return self.dirpath / f"{formatted_filename}{self.extension()}"

     def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
         current_metrics: dict[str, Any] = {
@@ -142,9 +211,22 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):

         current_metrics[name] = value

+        log.debug(
+            f"Current metrics: {current_metrics}, {trainer.callback_metrics=}, {trainer.logged_metrics=}"
+        )
         return current_metrics

     def save_checkpoints(self, trainer: Trainer):
+        log.debug(
+            f"{type(self).__name__}.save_checkpoints() called at {trainer.current_epoch=}, {trainer.global_step=}"
+        )
+        # Also print out the current stack trace for debugging
+        if log.isEnabledFor(logging.DEBUG):
+            import traceback
+
+            stack = traceback.extract_stack()
+            log.debug(f"Stack trace: {''.join(traceback.format_list(stack))}")
+
         if self._should_skip_saving_checkpoint(trainer):
             return

@@ -156,22 +238,73 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
             f"but got {type(trainer).__name__}"
         )

-        filepath = self.resolve_checkpoint_path(
-
+        current_metrics = self.current_metrics(trainer)
+        filepath = self.resolve_checkpoint_path(current_metrics)
+
+        # Get all existing checkpoint metadata
+        existing_metas = _get_checkpoint_metadata(self.dirpath)
+
+        # Determine which checkpoints to remove
+        to_remove: list[CheckpointMetadata] = []
+        should_save = True
+
+        # Check if we should save this checkpoint
+        if (topk := self.config.topk) != "all" and len(existing_metas) >= topk:
+            # Generate hypothetical metadata for the current checkpoint
+            hypothetical_meta = _generate_checkpoint_metadata(
+                trainer=trainer,
+                checkpoint_path=filepath,
+                metadata_path=filepath.with_suffix(CheckpointMetadata.PATH_SUFFIX),
+                compute_checksum=False,
+            )
+
+            # Add the hypothetical metadata to the list and sort
+            metas = _sort_checkpoint_metadata(
+                [*existing_metas, hypothetical_meta],
+                self.topk_sort_key,
+                self.topk_sort_reverse(),
+            )
+
+            # If the hypothetical metadata is not in the top-k, skip saving
+            if hypothetical_meta not in metas[:topk]:
+                log.debug(
+                    f"Skipping checkpoint save: would not make top {topk} "
+                    f"based on {self.topk_sort_key.__name__}"
+                )
+                should_save = False
+            else:
+                # Determine which existing checkpoints to remove
+                to_remove = metas[topk:]
+                assert hypothetical_meta not in to_remove, (
+                    "Hypothetical metadata should not be in the to_remove list."
+                )
+                log.debug(
+                    f"Removing checkpoints: {[meta.checkpoint_filename for meta in to_remove]} "
+                    f"and saving the new checkpoint: {hypothetical_meta.checkpoint_filename}"
+                )

-        if
-
+        # Only save if it would make it into the top-k
+        if should_save:
+            # Save the new checkpoint
+            trainer.save_checkpoint(
+                filepath,
+                weights_only=self.config.save_weights_only,
+            )

+        if trainer.is_global_zero:
+            # Remove old checkpoints that should be deleted
+            if to_remove:
+                _remove_checkpoints(trainer, self.dirpath, to_remove)
+
+            # Update the symlink to point to the best checkpoint
+            _update_symlink(
+                self.dirpath,
+                self.symlink_path(),
+                self.topk_sort_key,
+                self.topk_sort_reverse(),
+            )

-        # Barrier to ensure all processes have
-        # deleted the old checkpoints, and created the symlink before continuing
+        # Barrier to ensure all processes have completed checkpoint operations
         trainer.strategy.barrier()

     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
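The rewritten save_checkpoints decides up front whether the candidate checkpoint would make the top-k by sorting existing metadata together with a hypothetical entry for it, instead of saving first and pruning afterwards. A standalone sketch of that decision with plain dicts (field names are illustrative):

    candidate = {"name": "epoch=7.ckpt", "val_loss": 0.31}
    existing = [
        {"name": "epoch=3.ckpt", "val_loss": 0.42},
        {"name": "epoch=5.ckpt", "val_loss": 0.28},
    ]
    topk = 2
    ranked = sorted([*existing, candidate], key=lambda m: m["val_loss"])
    should_save = candidate in ranked[:topk]          # True: 0.31 beats 0.42
    to_remove = ranked[topk:] if should_save else []  # [epoch=3.ckpt]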
nshtrainer/callbacks/lr_monitor.py
CHANGED
@@ -1,12 +1,15 @@
 from __future__ import annotations

+import logging
 from typing import Literal

 from lightning.pytorch.callbacks import LearningRateMonitor
-from typing_extensions import final
+from typing_extensions import final, override

 from .base import CallbackConfigBase, callback_registry

+log = logging.getLogger(__name__)
+

 @final
 @callback_registry.register
@@ -28,7 +31,12 @@ class LearningRateMonitorConfig(CallbackConfigBase):
     Option to also log the weight decay values of the optimizer. Defaults to False.
     """

+    @override
     def create_callbacks(self, trainer_config):
+        if not list(trainer_config.enabled_loggers()):
+            log.warning("No loggers enabled. LearningRateMonitor will not be used.")
+            return
+
         yield LearningRateMonitor(
             logging_interval=self.logging_interval,
             log_momentum=self.log_momentum,
nshtrainer/configs/__init__.py
CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True

 from nshtrainer import MetricConfig as MetricConfig
 from nshtrainer import TrainerConfig as TrainerConfig
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
-from nshtrainer._directory import DirectoryConfig as DirectoryConfig
 from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
 from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -126,9 +125,9 @@ from nshtrainer.trainer._config import (
     CheckpointCallbackConfig as CheckpointCallbackConfig,
 )
 from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
 from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
 from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
 from nshtrainer.trainer.accelerator import (
@@ -227,7 +226,6 @@ from nshtrainer.util.config import EpochsConfig as EpochsConfig
 from nshtrainer.util.config import StepsConfig as StepsConfig

 from . import _checkpoint as _checkpoint
-from . import _directory as _directory
 from . import _hf_hub as _hf_hub
 from . import callbacks as callbacks
 from . import loggers as loggers
@@ -338,7 +336,6 @@ __all__ = [
     "RpropConfig",
     "SGDConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "SiLUNonlinearityConfig",
     "SigmoidNonlinearityConfig",
@@ -367,7 +364,6 @@ __all__ = [
     "XLAEnvironmentPlugin",
     "XLAPluginConfig",
     "_checkpoint",
-    "_directory",
     "_hf_hub",
     "accelerator_registry",
     "callback_registry",
nshtrainer/configs/trainer/__init__.py
CHANGED
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -51,7 +54,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -152,6 +154,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
@@ -180,7 +183,6 @@ __all__ = [
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "StrategyConfigBase",
nshtrainer/configs/trainer/_config/__init__.py
CHANGED
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -48,7 +51,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -70,6 +72,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
@@ -86,7 +89,6 @@ __all__ = [
     "PluginConfig",
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "TensorboardLoggerConfig",
nshtrainer/trainer/_config.py
CHANGED
@@ -26,7 +26,6 @@ from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypeAliasType, TypedDict, override

-from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
@@ -38,6 +37,7 @@ from ..callbacks import (
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
+from ..callbacks.directory_setup import DirectorySetupCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
 from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.metric_validation import MetricValidationCallbackConfig
@@ -352,19 +352,74 @@ class LightningTrainerKwargs(TypedDict, total=False):
     """


-
-
+DEFAULT_LOGDIR_BASENAME = "nshtrainer_logs"
+"""Default base name for the log directory."""
+
+
+class DirectoryConfig(C.Config):
+    project_root: Path | None = None
     """
-
-
-
-    Valid values are: "disable", "warn", "error".
+    Root directory for this project.
+
+    This isn't specific to the current run; it is the parent directory of all runs.
     """

+    logdir_basename: str = DEFAULT_LOGDIR_BASENAME
+    """Base name for the log directory."""
+
+    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
+    """Configuration for the directory setup PyTorch Lightning callback."""
+
+    def resolve_run_root_directory(self, run_id: str) -> Path:
+        if (project_root_dir := self.project_root) is None:
+            project_root_dir = Path.cwd()
+
+        # The default base dir is $CWD/{logdir_basename}/{id}/
+        base_dir = project_root_dir / self.logdir_basename
+        base_dir.mkdir(exist_ok=True)
+
+        # Add a .gitignore file to the {logdir_basename} directory
+        # which will ignore all files except for the .gitignore file itself
+        gitignore_path = base_dir / ".gitignore"
+        if not gitignore_path.exists():
+            gitignore_path.touch()
+            gitignore_path.write_text("*\n")
+
+        base_dir = base_dir / run_id
+        base_dir.mkdir(exist_ok=True)
+
+        return base_dir
+
+    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
+        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
+        if (subdir := getattr(self, subdirectory, None)) is not None:
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
+            return subdir
+
+        dir = self.resolve_run_root_directory(run_id)
+        dir = dir / subdirectory
+        dir.mkdir(exist_ok=True)
+        return dir
+
+    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+        if (log_dir := logger.log_dir) is not None:
+            return log_dir
+
+        # Save to {logdir_basename}/{id}/log/{logger name}
+        log_dir = self.resolve_subdirectory(run_id, "log")
+        log_dir = log_dir / logger.resolve_logger_dirname()
+        # ^ NOTE: Logger must have a `name` attribute, as this is
+        # the discriminator for the logger registry
+        log_dir.mkdir(exist_ok=True)
+
+        return log_dir
+

 class TrainerConfig(C.Config):
     # region Active Run Configuration
-    id: str
+    id: Annotated[str, C.AllowMissing()] = C.MISSING
     """ID of the run."""
     name: list[str] = []
     """Run name in parts. Full name is constructed by joining the parts with spaces."""
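DirectoryConfig (moved here from nshtrainer/_directory.py, see the deletion at the end of this diff) resolves all run artifacts under {project_root}/{logdir_basename}/{run_id}/. A usage sketch based on the methods shown above (the run id is a placeholder):

    from nshtrainer.trainer._config import DirectoryConfig

    dirs = DirectoryConfig()  # project_root defaults to the current working directory
    ckpt_dir = dirs.resolve_subdirectory("abc12345", "checkpoint")
    # -> ./nshtrainer_logs/abc12345/checkpoint, created on first use;
    #    a .gitignore containing "*" is dropped at the nshtrainer_logs level.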
@@ -393,39 +448,6 @@ class TrainerConfig(C.Config):

     directory: DirectoryConfig = DirectoryConfig()
     """Directory configuration options."""
-
-    _rng: ClassVar[np.random.Generator | None] = None
-
-    @classmethod
-    def generate_id(cls, *, length: int = 8) -> str:
-        """
-        Generate a random ID of specified length.
-
-        """
-        if (rng := cls._rng) is None:
-            rng = np.random.default_rng()
-
-        alphabet = list(string.ascii_lowercase + string.digits)
-
-        id = "".join(rng.choice(alphabet) for _ in range(length))
-        return id
-
-    @classmethod
-    def set_seed(cls, seed: int | None = None) -> None:
-        """
-        Set the seed for the random number generator.
-
-        Args:
-            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
-
-        Returns:
-            None
-        """
-        if seed is None:
-            seed = int(time.time() * 1000)
-        log.critical(f"Seeding {cls.__name__} with seed {seed}")
-        cls._rng = np.random.default_rng(seed)
-
     # endregion

     primary_metric: MetricConfig | None = None
@@ -695,8 +717,9 @@ class TrainerConfig(C.Config):

     auto_set_default_root_dir: bool = True
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
-    save_checkpoint_metadata:
-    """
+    save_checkpoint_metadata: Literal[True] = True
+    """Will save additional metadata whenever a checkpoint is saved.
+    This is a core feature of nshtrainer and cannot be disabled."""
     auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
     """If enabled, will automatically set the debug flag to True if:
     - The trainer is running in fast_dev_run mode.
@@ -755,40 +778,40 @@ class TrainerConfig(C.Config):
             None,
         )

-        yield self.lr_monitor
-        yield from (
-            logger_config
-            for logger_config in self.enabled_loggers()
-            if logger_config is not None
-            and isinstance(logger_config, CallbackConfigBase)
-        )
-        yield self.log_epoch
-        yield self.log_norms
-        yield self.hf_hub
-        yield self.shared_parameters
-        yield self.reduce_lr_on_plateau_sanity_checking
-        yield self.auto_set_debug_flag
-        yield self.auto_validate_metrics
-        yield from self.callbacks
-
-    def
-
-    # region Helper Methods
+    # region Helper Methods
+    def id_(self, value: str):
+        """
+        Set the id for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : str
+            The id value to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.id = value
+        return self
+
+    def with_id(self, value: str):
+        """
+        Create a copy of the current configuration with an updated id.
+
+        Parameters
+        ----------
+        value : str
+            The id value to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated id
+        """
+        return copy.deepcopy(self).id_(value)
+
     def fast_dev_run_(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
@@ -831,6 +854,349 @@ class TrainerConfig(C.Config):
         """
         return copy.deepcopy(self).project_root_(project_root)

+    def name_(self, *parts: str):
+        """
+        Set the name for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *parts : str
+            The parts of the name to set. Will be joined with spaces.
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.name = list(parts)
+        return self
+
+    def with_name(self, *parts: str):
+        """
+        Create a copy of the current configuration with an updated name.
+
+        Parameters
+        ----------
+        *parts : str
+            The parts of the name to set. Will be joined with spaces.
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated name
+        """
+        return copy.deepcopy(self).name_(*parts)
+
+    def project_(self, project: str | None):
+        """
+        Set the project name for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        project : str | None
+            The project name to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.project = project
+        return self
+
+    def with_project(self, project: str | None):
+        """
+        Create a copy of the current configuration with an updated project name.
+
+        Parameters
+        ----------
+        project : str | None
+            The project name to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated project name
+        """
+        return copy.deepcopy(self).project_(project)
+
+    def tags_(self, *tags: str):
+        """
+        Set the tags for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.tags = list(tags)
+        return self
+
+    def with_tags(self, *tags: str):
+        """
+        Create a copy of the current configuration with updated tags.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated tags
+        """
+        return copy.deepcopy(self).tags_(*tags)
+
+    def add_tags_(self, *tags: str):
+        """
+        Add tags to the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to add
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.tags.extend(tags)
+        return self
+
+    def with_added_tags(self, *tags: str):
+        """
+        Create a copy of the current configuration with additional tags.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to add
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the additional tags
+        """
+        return copy.deepcopy(self).add_tags_(*tags)
+
+    def notes_(self, *notes: str):
+        """
+        Set the notes for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.notes = list(notes)
+        return self
+
+    def with_notes(self, *notes: str):
+        """
+        Create a copy of the current configuration with updated notes.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated notes
+        """
+        return copy.deepcopy(self).notes_(*notes)
+
+    def add_notes_(self, *notes: str):
+        """
+        Add notes to the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to add
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.notes.extend(notes)
+        return self
+
+    def with_added_notes(self, *notes: str):
+        """
+        Create a copy of the current configuration with additional notes.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to add
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the additional notes
+        """
+        return copy.deepcopy(self).add_notes_(*notes)
+
+    def meta_(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
+        """
+        Update the `meta` dictionary in-place with the provided key-value pairs.
+
+        This method allows updating the meta information associated with the trainer
+        configuration by either passing a dictionary or keyword arguments.
+
+        Parameters
+        ----------
+        meta : dict[str, Any] | None, optional
+            A dictionary containing meta information to be added, by default None
+        **kwargs : Any
+            Additional key-value pairs to be added to the meta dictionary
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        if meta is not None:
+            self.meta.update(meta)
+        self.meta.update(kwargs)
+        return self
+
+    def with_meta(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
+        """
+        Create a copy of the current configuration with updated meta information.
+
+        This method is similar to `meta_`, but it returns a new instance of the configuration
+        with the updated meta information instead of modifying the current instance.
+
+        Parameters
+        ----------
+        meta : dict[str, Any] | None, optional
+            A dictionary containing meta information to be added, by default None
+        **kwargs : Any
+            Additional key-value pairs to be added to the meta dictionary
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with updated meta information
+        """
+
+        return self.model_copy(deep=True).meta_(meta, **kwargs)
+
+    def debug_(self, value: bool = True):
+        """
+        Set the debug flag for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The debug flag value to set, by default True
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.debug = value
+        return self
+
+    def with_debug(self, value: bool = True):
+        """
+        Create a copy of the current configuration with an updated debug flag.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The debug flag value to set, by default True
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated debug flag
+        """
+        return copy.deepcopy(self).debug_(value)
+
+    def ckpt_path_(self, path: Literal["none"] | str | Path | None):
+        """
+        Set the checkpoint path for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        path : Literal["none"] | str | Path | None
+            The checkpoint path to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.ckpt_path = path
+        return self
+
+    def with_ckpt_path(self, path: Literal["none"] | str | Path | None):
+        """
+        Create a copy of the current configuration with an updated checkpoint path.
+
+        Parameters
+        ----------
+        path : Literal["none"] | str | Path | None
+            The checkpoint path to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated checkpoint path
+        """
+        return copy.deepcopy(self).ckpt_path_(path)
+
+    def barebones_(self, value: bool = True):
+        """
+        Set the barebones flag for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The barebones flag value to set, by default True
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.barebones = value
+        return self
+
+    def with_barebones(self, value: bool = True):
+        """
+        Create a copy of the current configuration with an updated barebones flag.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The barebones flag value to set, by default True
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated barebones flag
+        """
+        return copy.deepcopy(self).barebones_(value)
+
     def reset_run(
         self,
         *,
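These helpers come in pairs: the trailing-underscore variants mutate in place, the with_* variants deep-copy first, and both return a config, so they chain. A usage sketch (it assumes a TrainerConfig can now be constructed without an explicit id, since id defaults to C.MISSING):

    from nshtrainer import TrainerConfig

    base = TrainerConfig()
    run = (
        base.with_project("my-project")
        .with_name("baseline", "fp32")
        .with_added_tags("debug")
        .with_meta(seed=0)
    )
    # `base` is untouched; `run` is a copy carrying the updated fields.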
@@ -873,3 +1239,89 @@ class TrainerConfig(C.Config):
         return config

     # endregion
+
+    # region Random ID Generation
+    _rng: ClassVar[np.random.Generator | None] = None
+
+    @classmethod
+    def generate_id(cls, *, length: int = 8) -> str:
+        """
+        Generate a random ID of specified length.
+
+        """
+        if (rng := cls._rng) is None:
+            rng = np.random.default_rng()
+
+        alphabet = list(string.ascii_lowercase + string.digits)
+
+        id = "".join(rng.choice(alphabet) for _ in range(length))
+        return id
+
+    @classmethod
+    def set_seed(cls, seed: int | None = None) -> None:
+        """
+        Set the seed for the random number generator.
+
+        Args:
+            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
+
+        Returns:
+            None
+        """
+        if seed is None:
+            seed = int(time.time() * 1000)
+        log.critical(f"Seeding {cls.__name__} with seed {seed}")
+        cls._rng = np.random.default_rng(seed)
+
+    # endregion
+
+    # region Internal Methods
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+        yield self.directory.setup_callback
+        yield self.early_stopping
+        yield self.checkpoint_saving
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.enabled_loggers()
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
+        yield self.hf_hub
+        yield self.shared_parameters
+        yield self.reduce_lr_on_plateau_sanity_checking
+        yield self.auto_set_debug_flag
+        yield self.auto_validate_metrics
+        yield from self.callbacks
+
+    def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
+        # Disable all loggers if barebones mode is enabled
+        if self.barebones:
+            return
+
+        yield from self.enabled_loggers()
+        yield self.actsave_logger
+
+    def _nshtrainer_validate_before_run(self):
+        # shared_parameters is not supported under barebones mode
+        if self.barebones and self.shared_parameters:
+            raise ValueError("shared_parameters is not supported under barebones mode")
+
+        if not self.save_checkpoint_metadata:
+            raise ValueError(
+                "save_checkpoint_metadata must be True. This is a core feature of nshtrainer and cannot be disabled."
+            )
+
+    def _nshtrainer_set_id_if_missing(self):
+        """
+        Set the ID for the configuration object if it is missing.
+        """
+        if self.id is C.MISSING:
+            self.id = self.generate_id()
+            log.info(f"TrainerConfig's run ID is missing, setting to {self.id}.")
+        else:
+            log.debug(f"TrainerConfig's run ID is already set to {self.id}.")
+
+    # endregion
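Together with id: Annotated[str, C.AllowMissing()] = C.MISSING above, these methods make the run id optional up front: the trainer fills a missing id via generate_id(), and the generator can be seeded for reproducible ids. A short sketch:

    from nshtrainer import TrainerConfig

    TrainerConfig.set_seed(1234)        # make generated ids deterministic
    print(TrainerConfig.generate_id())  # an 8-character lowercase/digit id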
nshtrainer/trainer/trainer.py
CHANGED
@@ -45,6 +45,9 @@ patch_log_hparams_function()


 class Trainer(LightningTrainer):
+    profiler: Profiler
+    """Profiler used for profiling the training process."""
+
     CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"

     @property
@@ -316,6 +319,7 @@ class Trainer(LightningTrainer):
                 f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
                 f"Got {type(hparams)=} instead."
             )
+        hparams._nshtrainer_set_id_if_missing()
         hparams = hparams.model_deep_validate()
         hparams._nshtrainer_validate_before_run()

@@ -468,6 +472,11 @@ class Trainer(LightningTrainer):
         weights_only: bool = False,
         storage_options: Any | None = None,
     ):
+        assert self.hparams.save_checkpoint_metadata, (
+            "Checkpoint metadata is not enabled. "
+            "Please set `hparams.save_checkpoint_metadata=True`."
+        )
+
         filepath = Path(filepath)

         if self.model is None:
@@ -475,7 +484,7 @@ class Trainer(LightningTrainer):
                 "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
                 " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
             )
-        with self.profiler.profile("save_checkpoint"):
+        with self.profiler.profile("save_checkpoint"):
             checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
             # Update the checkpoint for the trainer hyperparameters
             checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY] = self.hparams.model_dump(
@@ -488,7 +497,7 @@ class Trainer(LightningTrainer):

         # Save the checkpoint metadata
         metadata_path = None
-        if self.
+        if self.is_global_zero:
             # Generate the metadata and write to disk
             metadata_path = write_checkpoint_metadata(self, filepath)
{nshtrainer-1.3.5.dist-info → nshtrainer-1.4.0.dist-info}/RECORD
CHANGED
@@ -1,16 +1,15 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
 nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
-nshtrainer/_checkpoint/metadata.py,sha256=
+nshtrainer/_checkpoint/metadata.py,sha256=El9Ip8jGA7mAN5rAMpVfg1dfUe2dGoOOfvF1JfYJGHM,5676
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
-nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
-nshtrainer/_hf_hub.py,sha256=
+nshtrainer/_hf_hub.py,sha256=OB4252GJ6AbKNCRmHVvEglvjYVMUN822BFYECABxfZU,14037
 nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
 nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
-nshtrainer/callbacks/checkpoint/_base.py,sha256=
+nshtrainer/callbacks/checkpoint/_base.py,sha256=BjgfCXsf4Ihf1MNKkHBUwjHMLwc04PZO-2Bx-LdAazg,11010
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -23,7 +22,7 @@ nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH
 nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
-nshtrainer/callbacks/lr_monitor.py,sha256=
+nshtrainer/callbacks/lr_monitor.py,sha256=v45ehnwNO987087HfiOY5aIrVRbwdKMgPYRFHs1fyEE,1444
 nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
 nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
 nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
@@ -33,10 +32,9 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
 nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256
+nshtrainer/configs/__init__.py,sha256=-yJ5Uk9VkANqfk-QnX2aynL0jSf7cJQuQNzT1GAE1x8,15684
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
-nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
 nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
 nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
@@ -85,8 +83,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=DM2PlB4WRDZ_dqEeW91LbKRFa4sIF_pETU0T9GYJ5-g,8073
+nshtrainer/configs/trainer/_config/__init__.py,sha256=z5UpuXktBanLOYNkkbgbbHE06iQtcSuAKTpnx2TLmCo,3850
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -135,7 +133,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=FWEspBYt_bjLhUSkJApkC9pfYBTlFBHmIQRFNGpGjAc,45849
 nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -148,7 +146,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=ZgbSkbthoe8MYN6rBoFf-7UDpQtc9fs9pG_FNvTYSfs,10962
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=G_tHqzZCHJazhROcoKeOI5rZ5A8F8XlghiIWkdMbPR0,24387
 nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
@@ -161,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.
-nshtrainer-1.
-nshtrainer-1.
+nshtrainer-1.4.0.dist-info/METADATA,sha256=PIV_5Swp1HhgFU2ZBj_X1tCeOBfNhrhTXOFB1vgunno,979
+nshtrainer-1.4.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.4.0.dist-info/RECORD,,
nshtrainer/_directory.py
DELETED
@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-import logging
-from pathlib import Path
-
-import nshconfig as C
-
-from .callbacks.directory_setup import DirectorySetupCallbackConfig
-from .loggers import LoggerConfig
-
-log = logging.getLogger(__name__)
-
-
-class DirectoryConfig(C.Config):
-    project_root: Path | None = None
-    """
-    Root directory for this project.
-
-    This isn't specific to the run; it is the parent directory of all runs.
-    """
-
-    logdir_basename: str = "nshtrainer"
-    """Base name for the log directory."""
-
-    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
-    """Configuration for the directory setup PyTorch Lightning callback."""
-
-    def resolve_run_root_directory(self, run_id: str) -> Path:
-        if (project_root_dir := self.project_root) is None:
-            project_root_dir = Path.cwd()
-
-        # The default base dir is $CWD/{logdir_basename}/{id}/
-        base_dir = project_root_dir / self.logdir_basename
-        base_dir.mkdir(exist_ok=True)
-
-        # Add a .gitignore file to the {logdir_basename} directory
-        # which will ignore all files except for the .gitignore file itself
-        gitignore_path = base_dir / ".gitignore"
-        if not gitignore_path.exists():
-            gitignore_path.touch()
-            gitignore_path.write_text("*\n")
-
-        base_dir = base_dir / run_id
-        base_dir.mkdir(exist_ok=True)
-
-        return base_dir
-
-    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
-        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
-        if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(subdir, Path), (
-                f"Expected a Path for {subdirectory}, got {type(subdir)}"
-            )
-            return subdir
-
-        dir = self.resolve_run_root_directory(run_id)
-        dir = dir / subdirectory
-        dir.mkdir(exist_ok=True)
-        return dir
-
-    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
-        if (log_dir := logger.log_dir) is not None:
-            return log_dir
-
-        # Save to {logdir_basename}/{id}/log/{logger name}
-        log_dir = self.resolve_subdirectory(run_id, "log")
-        log_dir = log_dir / logger.resolve_logger_dirname()
-        # ^ NOTE: Logger must have a `name` attribute, as this is
-        # the discriminator for the logger registry
-        log_dir.mkdir(exist_ok=True)
-
-        return log_dir
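With nshtrainer/_directory.py gone, DirectoryConfig now lives in nshtrainer.trainer._config and is re-exported from nshtrainer.configs (see the import changes above). Code importing it from the old module needs an update along these lines:

    # 1.3.5:
    #   from nshtrainer._directory import DirectoryConfig
    # 1.4.0:
    from nshtrainer.trainer._config import DirectoryConfig
    # or via the generated configs package:
    from nshtrainer.configs import DirectoryConfig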
nshtrainer/configs/_directory/__init__.py
DELETED
@@ -1,15 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from nshtrainer._directory import DirectoryConfig as DirectoryConfig
-from nshtrainer._directory import (
-    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
-)
-from nshtrainer._directory import LoggerConfig as LoggerConfig
-
-__all__ = [
-    "DirectoryConfig",
-    "DirectorySetupCallbackConfig",
-    "LoggerConfig",
-]
{nshtrainer-1.3.5.dist-info → nshtrainer-1.4.0.dist-info}/WHEEL
File without changes