nshtrainer 1.3.5__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in those public registries.
nshtrainer/__init__.py CHANGED
@@ -19,3 +19,17 @@ try:
19
19
  from . import configs as configs
20
20
  except BaseException:
21
21
  pass
22
+
23
+ try:
24
+ from importlib.metadata import PackageNotFoundError, version
25
+ except ImportError:
26
+ # For Python <3.8
27
+ from importlib_metadata import ( # pyright: ignore[reportMissingImports]
28
+ PackageNotFoundError,
29
+ version,
30
+ )
31
+
32
+ try:
33
+ __version__ = version(__name__)
34
+ except PackageNotFoundError:
35
+ __version__ = "unknown"
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -85,6 +85,7 @@ def _generate_checkpoint_metadata(
85
85
  trainer: Trainer,
86
86
  checkpoint_path: Path,
87
87
  metadata_path: Path,
88
+ compute_checksum: bool = True,
88
89
  ):
89
90
  checkpoint_timestamp = datetime.datetime.now()
90
91
  start_timestamp = trainer.start_time()
@@ -105,7 +106,9 @@ def _generate_checkpoint_metadata(
105
106
  # moving the checkpoint directory
106
107
  checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
107
108
  checkpoint_filename=checkpoint_path.name,
108
- checkpoint_checksum=compute_file_checksum(checkpoint_path),
109
+ checkpoint_checksum=compute_file_checksum(checkpoint_path)
110
+ if compute_checksum
111
+ else "",
109
112
  run_id=trainer.hparams.id,
110
113
  name=trainer.hparams.full_name,
111
114
  project=trainer.hparams.project,
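
Editor's note: the new `compute_checksum` flag lets callers skip the hashing cost when metadata is only needed for comparison (it is passed as `False` for hypothetical, not-yet-written checkpoints later in this diff). `compute_file_checksum` is the package's own helper; purely as an illustration of what such a helper typically does (an assumed implementation, not nshtrainer's):

    import hashlib
    from pathlib import Path

    def sha256_checksum(path: Path, chunk_size: int = 1 << 20) -> str:
        # Stream in chunks so large checkpoint files never need to fit in memory
        digest = hashlib.sha256()
        with path.open("rb") as f:
            while chunk := f.read(chunk_size):
                digest.update(chunk)
        return digest.hexdigest()
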
nshtrainer/_hf_hub.py CHANGED
@@ -91,6 +91,9 @@ class HuggingFaceHubConfig(CallbackConfigBase):
91
91
 
92
92
  @override
93
93
  def create_callbacks(self, trainer_config):
94
+ if not self:
95
+ return
96
+
94
97
  # Attempt to login. If it fails, we'll log a warning or error based on the configuration.
95
98
  try:
96
99
  api = _api(self.token)
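
Editor's note: `create_callbacks` now returns early when the config evaluates falsy. A minimal sketch of that guard pattern with a hypothetical config class (nshtrainer's actual truthiness logic lives in HuggingFaceHubConfig and may differ):

    from dataclasses import dataclass

    @dataclass
    class HubConfig:
        enabled: bool = True

        def __bool__(self) -> bool:
            return self.enabled

        def create_callbacks(self):
            # A disabled config contributes no callbacks at all
            if not self:
                return
            yield "hub-callback-would-go-here"

    print(list(HubConfig(enabled=False).create_callbacks()))  # -> []
    print(list(HubConfig(enabled=True).create_callbacks()))   # -> ['hub-callback-would-go-here']
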
nshtrainer/callbacks/checkpoint/_base.py CHANGED
@@ -1,17 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ import string
4
5
  from abc import ABC, abstractmethod
6
+ from collections.abc import Callable
5
7
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Any, Generic, Literal
8
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
7
9
 
8
10
  import numpy as np
9
11
  import torch
10
12
  from lightning.pytorch import Trainer
11
13
  from lightning.pytorch.callbacks import Checkpoint
12
- from typing_extensions import TypeVar, override
14
+ from typing_extensions import override
13
15
 
14
- from ..._checkpoint.metadata import CheckpointMetadata
16
+ from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
15
17
  from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
16
18
  from ..base import CallbackConfigBase
17
19
 
@@ -22,6 +24,81 @@ if TYPE_CHECKING:
22
24
  log = logging.getLogger(__name__)
23
25
 
24
26
 
27
+ class _FormatDict(dict):
28
+ """A dictionary that returns an empty string for missing keys when formatting."""
29
+
30
+ def __missing__(self, key):
31
+ log.debug(
32
+ f"Missing format key '{key}' in checkpoint filename, using empty string"
33
+ )
34
+ return ""
35
+
36
+
37
+ def _get_checkpoint_metadata(dirpath: Path) -> list[CheckpointMetadata]:
38
+ """Get all checkpoint metadata from a directory."""
39
+ return [
40
+ CheckpointMetadata.from_file(p)
41
+ for p in dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
42
+ if p.is_file() and not p.is_symlink()
43
+ ]
44
+
45
+
46
+ def _sort_checkpoint_metadata(
47
+ metas: list[CheckpointMetadata],
48
+ key_fn: Callable[[CheckpointMetadata], Any],
49
+ reverse: bool = False,
50
+ ) -> list[CheckpointMetadata]:
51
+ """Sort checkpoint metadata by the given key function."""
52
+ return sorted(metas, key=key_fn, reverse=reverse)
53
+
54
+
55
+ def _remove_checkpoints(
56
+ trainer: Trainer,
57
+ dirpath: Path,
58
+ metas_to_remove: list[CheckpointMetadata],
59
+ ) -> None:
60
+ """Remove checkpoint files and their metadata."""
61
+ for meta in metas_to_remove:
62
+ ckpt_path = dirpath / meta.checkpoint_filename
63
+ if not ckpt_path.exists():
64
+ log.warning(
65
+ f"Checkpoint file not found: {ckpt_path}\n"
66
+ "Skipping removal of the checkpoint metadata."
67
+ )
68
+ continue
69
+
70
+ remove_checkpoint(trainer, ckpt_path, metadata=True)
71
+ log.debug(f"Removed checkpoint: {ckpt_path}")
72
+
73
+
74
+ def _update_symlink(
75
+ dirpath: Path,
76
+ symlink_path: Path | None,
77
+ sort_key_fn: Callable[[CheckpointMetadata], Any],
78
+ sort_reverse: bool,
79
+ ) -> None:
80
+ """Update symlink to point to the best checkpoint."""
81
+ if symlink_path is None:
82
+ return
83
+
84
+ # Get all checkpoint metadata after any removals
85
+ remaining_metas = _get_checkpoint_metadata(dirpath)
86
+
87
+ if remaining_metas:
88
+ # Sort by the key function
89
+ remaining_metas = _sort_checkpoint_metadata(
90
+ remaining_metas, sort_key_fn, sort_reverse
91
+ )
92
+
93
+ # Link to the best checkpoint
94
+ best_meta = remaining_metas[0]
95
+ best_filepath = dirpath / best_meta.checkpoint_filename
96
+ link_checkpoint(best_filepath, symlink_path, metadata=True)
97
+ log.debug(f"Updated symlink {symlink_path.name} -> {best_filepath.name}")
98
+ else:
99
+ log.warning(f"No checkpoints found in {dirpath} to create symlink.")
100
+
101
+
25
102
  class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
26
103
  dirpath: str | Path | None = None
27
104
  """Directory path to save the checkpoint file."""
@@ -95,35 +172,27 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
95
172
  def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
96
173
  if (filename := self.config.filename) is None:
97
174
  filename = self.default_filename()
98
- filename = filename.format(**current_metrics)
99
- return self.dirpath / f"{filename}{self.extension()}"
100
-
101
- def remove_old_checkpoints(self, trainer: Trainer):
102
- if (topk := self.config.topk) == "all":
103
- return
104
175
 
105
- # Get all the checkpoint metadata
106
- metas = [
107
- CheckpointMetadata.from_file(p)
108
- for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
109
- if p.is_file() and not p.is_symlink()
176
+ # Extract all field names from the format string
177
+ field_names = [
178
+ fname for _, fname, _, _ in string.Formatter().parse(filename) if fname
110
179
  ]
111
180
 
112
- # Sort by the topk sort key
113
- metas = sorted(metas, key=self.topk_sort_key, reverse=self.topk_sort_reverse())
181
+ # Filter current_metrics to only include keys that are in the format string
182
+ format_dict = {k: v for k, v in current_metrics.items() if k in field_names}
114
183
 
115
- # Now, the metas are sorted from the best to the worst,
116
- # so we can remove the worst checkpoints
117
- for meta in metas[topk:]:
118
- if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
119
- log.warning(
120
- f"Checkpoint file not found: {old_ckpt_path}\n"
121
- "Skipping removal of the checkpoint metadata."
122
- )
123
- continue
184
+ try:
185
+ formatted_filename = filename.format(**format_dict)
186
+ except KeyError as e:
187
+ log.warning(
188
+ f"Missing key {e} in {filename=} with {format_dict=}. Using default values."
189
+ )
190
+ # Provide a simple fallback for missing keys
191
+ formatted_filename = string.Formatter().vformat(
192
+ filename, (), _FormatDict(format_dict)
193
+ )
124
194
 
125
- remove_checkpoint(trainer, old_ckpt_path, metadata=True)
126
- log.debug(f"Removed old checkpoint: {old_ckpt_path}")
195
+ return self.dirpath / f"{formatted_filename}{self.extension()}"
127
196
 
128
197
  def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
129
198
  current_metrics: dict[str, Any] = {
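
Editor's note: the reworked `resolve_checkpoint_path` tolerates metric keys that are missing from the filename template. A minimal standalone sketch of the same `string.Formatter` technique (names here are illustrative):

    import string

    class FormatDict(dict):
        """Return an empty string for any key missing at format time."""

        def __missing__(self, key):
            return ""

    template = "epoch{epoch}-step{step}-loss{val_loss}"
    metrics = {"epoch": 3, "step": 1200, "lr": 1e-4}  # val_loss is absent, lr is unused

    # Field names actually referenced by the template
    fields = [name for _, name, _, _ in string.Formatter().parse(template) if name]
    # -> ['epoch', 'step', 'val_loss']

    # Keep only referenced metrics, then format; missing keys degrade to ""
    format_dict = {k: v for k, v in metrics.items() if k in fields}
    filename = string.Formatter().vformat(template, (), FormatDict(format_dict))
    # -> "epoch3-step1200-loss"
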
@@ -142,9 +211,22 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
142
211
 
143
212
  current_metrics[name] = value
144
213
 
214
+ log.debug(
215
+ f"Current metrics: {current_metrics}, {trainer.callback_metrics=}, {trainer.logged_metrics=}"
216
+ )
145
217
  return current_metrics
146
218
 
147
219
  def save_checkpoints(self, trainer: Trainer):
220
+ log.debug(
221
+ f"{type(self).__name__}.save_checkpoints() called at {trainer.current_epoch=}, {trainer.global_step=}"
222
+ )
223
+ # Also print out the current stack trace for debugging
224
+ if log.isEnabledFor(logging.DEBUG):
225
+ import traceback
226
+
227
+ stack = traceback.extract_stack()
228
+ log.debug(f"Stack trace: {''.join(traceback.format_list(stack))}")
229
+
148
230
  if self._should_skip_saving_checkpoint(trainer):
149
231
  return
150
232
 
@@ -156,22 +238,73 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
156
238
  f"but got {type(trainer).__name__}"
157
239
  )
158
240
 
159
- # Save the new checkpoint
160
- filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
161
- trainer.save_checkpoint(filepath, self.config.save_weights_only)
241
+ current_metrics = self.current_metrics(trainer)
242
+ filepath = self.resolve_checkpoint_path(current_metrics)
243
+
244
+ # Get all existing checkpoint metadata
245
+ existing_metas = _get_checkpoint_metadata(self.dirpath)
246
+
247
+ # Determine which checkpoints to remove
248
+ to_remove: list[CheckpointMetadata] = []
249
+ should_save = True
250
+
251
+ # Check if we should save this checkpoint
252
+ if (topk := self.config.topk) != "all" and len(existing_metas) >= topk:
253
+ # Generate hypothetical metadata for the current checkpoint
254
+ hypothetical_meta = _generate_checkpoint_metadata(
255
+ trainer=trainer,
256
+ checkpoint_path=filepath,
257
+ metadata_path=filepath.with_suffix(CheckpointMetadata.PATH_SUFFIX),
258
+ compute_checksum=False,
259
+ )
260
+
261
+ # Add the hypothetical metadata to the list and sort
262
+ metas = _sort_checkpoint_metadata(
263
+ [*existing_metas, hypothetical_meta],
264
+ self.topk_sort_key,
265
+ self.topk_sort_reverse(),
266
+ )
267
+
268
+ # If the hypothetical metadata is not in the top-k, skip saving
269
+ if hypothetical_meta not in metas[:topk]:
270
+ log.debug(
271
+ f"Skipping checkpoint save: would not make top {topk} "
272
+ f"based on {self.topk_sort_key.__name__}"
273
+ )
274
+ should_save = False
275
+ else:
276
+ # Determine which existing checkpoints to remove
277
+ to_remove = metas[topk:]
278
+ assert hypothetical_meta not in to_remove, (
279
+ "Hypothetical metadata should not be in the to_remove list."
280
+ )
281
+ log.debug(
282
+ f"Removing checkpoints: {[meta.checkpoint_filename for meta in to_remove]} "
283
+ f"and saving the new checkpoint: {hypothetical_meta.checkpoint_filename}"
284
+ )
162
285
 
163
- if trainer.hparams.save_checkpoint_metadata and trainer.is_global_zero:
164
- # Remove old checkpoints
165
- self.remove_old_checkpoints(trainer)
286
+ # Only save if it would make it into the top-k
287
+ if should_save:
288
+ # Save the new checkpoint
289
+ trainer.save_checkpoint(
290
+ filepath,
291
+ weights_only=self.config.save_weights_only,
292
+ )
166
293
 
167
- # Create the latest symlink
168
- if (symlink_filename := self.symlink_path()) is not None:
169
- symlink_path = self.dirpath / symlink_filename
170
- link_checkpoint(filepath, symlink_path, metadata=True)
171
- log.debug(f"Created latest symlink: {symlink_path}")
294
+ if trainer.is_global_zero:
295
+ # Remove old checkpoints that should be deleted
296
+ if to_remove:
297
+ _remove_checkpoints(trainer, self.dirpath, to_remove)
298
+
299
+ # Update the symlink to point to the best checkpoint
300
+ _update_symlink(
301
+ self.dirpath,
302
+ self.symlink_path(),
303
+ self.topk_sort_key,
304
+ self.topk_sort_reverse(),
305
+ )
172
306
 
173
- # Barrier to ensure all processes have saved the checkpoint,
174
- # deleted the old checkpoints, and created the symlink before continuing
307
+ # Barrier to ensure all processes have completed checkpoint operations
175
308
  trainer.strategy.barrier()
176
309
 
177
310
  def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
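
Editor's note: the new `save_checkpoints` flow ranks a hypothetical entry for the candidate checkpoint against the existing ones before anything is written. A minimal sketch of that top-k decision using plain tuples in place of CheckpointMetadata (illustrative only):

    from typing import NamedTuple

    class Meta(NamedTuple):
        filename: str
        val_loss: float

    existing = [Meta("a.ckpt", 0.31), Meta("b.ckpt", 0.27), Meta("c.ckpt", 0.45)]
    candidate = Meta("d.ckpt", 0.29)  # metadata for a checkpoint that has NOT been written yet
    topk = 2

    # Rank existing checkpoints together with the hypothetical new one (lower loss is better)
    ranked = sorted([*existing, candidate], key=lambda m: m.val_loss)

    if candidate not in ranked[:topk]:
        print("skip: candidate would not enter the top-k, so nothing is written")
    else:
        to_remove = ranked[topk:]  # checkpoints pushed out of the top-k
        print(f"save {candidate.filename}, remove {[m.filename for m in to_remove]}")
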
nshtrainer/callbacks/lr_monitor.py CHANGED
@@ -1,12 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  from typing import Literal
4
5
 
5
6
  from lightning.pytorch.callbacks import LearningRateMonitor
6
- from typing_extensions import final
7
+ from typing_extensions import final, override
7
8
 
8
9
  from .base import CallbackConfigBase, callback_registry
9
10
 
11
+ log = logging.getLogger(__name__)
12
+
10
13
 
11
14
  @final
12
15
  @callback_registry.register
@@ -28,7 +31,12 @@ class LearningRateMonitorConfig(CallbackConfigBase):
28
31
  Option to also log the weight decay values of the optimizer. Defaults to False.
29
32
  """
30
33
 
34
+ @override
31
35
  def create_callbacks(self, trainer_config):
36
+ if not list(trainer_config.enabled_loggers()):
37
+ log.warning("No loggers enabled. LearningRateMonitor will not be used.")
38
+ return
39
+
32
40
  yield LearningRateMonitor(
33
41
  logging_interval=self.logging_interval,
34
42
  log_momentum=self.log_momentum,
nshtrainer/configs/__init__.py CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True
5
5
  from nshtrainer import MetricConfig as MetricConfig
6
6
  from nshtrainer import TrainerConfig as TrainerConfig
7
7
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
8
- from nshtrainer._directory import DirectoryConfig as DirectoryConfig
9
8
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
10
9
  from nshtrainer._hf_hub import (
11
10
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -126,9 +125,9 @@ from nshtrainer.trainer._config import (
126
125
  CheckpointCallbackConfig as CheckpointCallbackConfig,
127
126
  )
128
127
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
128
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
129
129
  from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
130
130
  from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
131
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
132
131
  from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
133
132
  from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
134
133
  from nshtrainer.trainer.accelerator import (
@@ -227,7 +226,6 @@ from nshtrainer.util.config import EpochsConfig as EpochsConfig
227
226
  from nshtrainer.util.config import StepsConfig as StepsConfig
228
227
 
229
228
  from . import _checkpoint as _checkpoint
230
- from . import _directory as _directory
231
229
  from . import _hf_hub as _hf_hub
232
230
  from . import callbacks as callbacks
233
231
  from . import loggers as loggers
@@ -338,7 +336,6 @@ __all__ = [
338
336
  "RpropConfig",
339
337
  "SGDConfig",
340
338
  "SLURMEnvironmentPlugin",
341
- "SanityCheckingConfig",
342
339
  "SharedParametersCallbackConfig",
343
340
  "SiLUNonlinearityConfig",
344
341
  "SigmoidNonlinearityConfig",
@@ -367,7 +364,6 @@ __all__ = [
367
364
  "XLAEnvironmentPlugin",
368
365
  "XLAPluginConfig",
369
366
  "_checkpoint",
370
- "_directory",
371
367
  "_hf_hub",
372
368
  "accelerator_registry",
373
369
  "callback_registry",
nshtrainer/configs/trainer/__init__.py CHANGED
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
22
22
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
23
23
  )
24
24
  from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
25
+ from nshtrainer.trainer._config import (
26
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
27
+ )
25
28
  from nshtrainer.trainer._config import (
26
29
  EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
27
30
  )
@@ -51,7 +54,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
51
54
  from nshtrainer.trainer._config import (
52
55
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
53
56
  )
54
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
55
57
  from nshtrainer.trainer._config import (
56
58
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
57
59
  )
@@ -152,6 +154,7 @@ __all__ = [
152
154
  "DebugFlagCallbackConfig",
153
155
  "DeepSpeedPluginConfig",
154
156
  "DirectoryConfig",
157
+ "DirectorySetupCallbackConfig",
155
158
  "DistributedPredictionWriterConfig",
156
159
  "DoublePrecisionPluginConfig",
157
160
  "EarlyStoppingCallbackConfig",
@@ -180,7 +183,6 @@ __all__ = [
180
183
  "ProfilerConfig",
181
184
  "RLPSanityChecksCallbackConfig",
182
185
  "SLURMEnvironmentPlugin",
183
- "SanityCheckingConfig",
184
186
  "SharedParametersCallbackConfig",
185
187
  "StrategyConfig",
186
188
  "StrategyConfigBase",
nshtrainer/configs/trainer/_config/__init__.py CHANGED
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
18
18
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
19
19
  )
20
20
  from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
21
+ from nshtrainer.trainer._config import (
22
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
23
+ )
21
24
  from nshtrainer.trainer._config import (
22
25
  EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
23
26
  )
@@ -48,7 +51,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
48
51
  from nshtrainer.trainer._config import (
49
52
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
50
53
  )
51
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
52
54
  from nshtrainer.trainer._config import (
53
55
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
54
56
  )
@@ -70,6 +72,7 @@ __all__ = [
70
72
  "CheckpointSavingConfig",
71
73
  "DebugFlagCallbackConfig",
72
74
  "DirectoryConfig",
75
+ "DirectorySetupCallbackConfig",
73
76
  "EarlyStoppingCallbackConfig",
74
77
  "EnvironmentConfig",
75
78
  "GradientClippingConfig",
@@ -86,7 +89,6 @@ __all__ = [
86
89
  "PluginConfig",
87
90
  "ProfilerConfig",
88
91
  "RLPSanityChecksCallbackConfig",
89
- "SanityCheckingConfig",
90
92
  "SharedParametersCallbackConfig",
91
93
  "StrategyConfig",
92
94
  "TensorboardLoggerConfig",
nshtrainer/trainer/_config.py CHANGED
@@ -26,7 +26,6 @@ from lightning.pytorch.profilers import Profiler
26
26
  from lightning.pytorch.strategies.strategy import Strategy
27
27
  from typing_extensions import TypeAliasType, TypedDict, override
28
28
 
29
- from .._directory import DirectoryConfig
30
29
  from .._hf_hub import HuggingFaceHubConfig
31
30
  from ..callbacks import (
32
31
  BestCheckpointCallbackConfig,
@@ -38,6 +37,7 @@ from ..callbacks import (
38
37
  )
39
38
  from ..callbacks.base import CallbackConfigBase
40
39
  from ..callbacks.debug_flag import DebugFlagCallbackConfig
40
+ from ..callbacks.directory_setup import DirectorySetupCallbackConfig
41
41
  from ..callbacks.log_epoch import LogEpochCallbackConfig
42
42
  from ..callbacks.lr_monitor import LearningRateMonitorConfig
43
43
  from ..callbacks.metric_validation import MetricValidationCallbackConfig
@@ -352,19 +352,74 @@ class LightningTrainerKwargs(TypedDict, total=False):
352
352
  """
353
353
 
354
354
 
355
- class SanityCheckingConfig(C.Config):
356
- reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
355
+ DEFAULT_LOGDIR_BASENAME = "nshtrainer_logs"
356
+ """Default base name for the log directory."""
357
+
358
+
359
+ class DirectoryConfig(C.Config):
360
+ project_root: Path | None = None
357
361
  """
358
- If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
359
- - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
360
- - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
361
- Valid values are: "disable", "warn", "error".
362
+ Root directory for this project.
363
+
364
+ This isn't specific to the current run; it is the parent directory of all runs.
362
365
  """
363
366
 
367
+ logdir_basename: str = DEFAULT_LOGDIR_BASENAME
368
+ """Base name for the log directory."""
369
+
370
+ setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
371
+ """Configuration for the directory setup PyTorch Lightning callback."""
372
+
373
+ def resolve_run_root_directory(self, run_id: str) -> Path:
374
+ if (project_root_dir := self.project_root) is None:
375
+ project_root_dir = Path.cwd()
376
+
377
+ # The default base dir is $CWD/{logdir_basename}/{id}/
378
+ base_dir = project_root_dir / self.logdir_basename
379
+ base_dir.mkdir(exist_ok=True)
380
+
381
+ # Add a .gitignore file to the {logdir_basename} directory
382
+ # which will ignore all files except for the .gitignore file itself
383
+ gitignore_path = base_dir / ".gitignore"
384
+ if not gitignore_path.exists():
385
+ gitignore_path.touch()
386
+ gitignore_path.write_text("*\n")
387
+
388
+ base_dir = base_dir / run_id
389
+ base_dir.mkdir(exist_ok=True)
390
+
391
+ return base_dir
392
+
393
+ def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
394
+ # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
395
+ if (subdir := getattr(self, subdirectory, None)) is not None:
396
+ assert isinstance(subdir, Path), (
397
+ f"Expected a Path for {subdirectory}, got {type(subdir)}"
398
+ )
399
+ return subdir
400
+
401
+ dir = self.resolve_run_root_directory(run_id)
402
+ dir = dir / subdirectory
403
+ dir.mkdir(exist_ok=True)
404
+ return dir
405
+
406
+ def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
407
+ if (log_dir := logger.log_dir) is not None:
408
+ return log_dir
409
+
410
+ # Save to {logdir_basename}/{id}/log/{logger name}
411
+ log_dir = self.resolve_subdirectory(run_id, "log")
412
+ log_dir = log_dir / logger.resolve_logger_dirname()
413
+ # ^ NOTE: Logger must have a `name` attribute, as this is
414
+ # the discriminator for the logger registry
415
+ log_dir.mkdir(exist_ok=True)
416
+
417
+ return log_dir
418
+
364
419
 
365
420
  class TrainerConfig(C.Config):
366
421
  # region Active Run Configuration
367
- id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
422
+ id: Annotated[str, C.AllowMissing()] = C.MISSING
368
423
  """ID of the run."""
369
424
  name: list[str] = []
370
425
  """Run name in parts. Full name is constructed by joining the parts with spaces."""
@@ -393,39 +448,6 @@ class TrainerConfig(C.Config):
393
448
 
394
449
  directory: DirectoryConfig = DirectoryConfig()
395
450
  """Directory configuration options."""
396
-
397
- _rng: ClassVar[np.random.Generator | None] = None
398
-
399
- @classmethod
400
- def generate_id(cls, *, length: int = 8) -> str:
401
- """
402
- Generate a random ID of specified length.
403
-
404
- """
405
- if (rng := cls._rng) is None:
406
- rng = np.random.default_rng()
407
-
408
- alphabet = list(string.ascii_lowercase + string.digits)
409
-
410
- id = "".join(rng.choice(alphabet) for _ in range(length))
411
- return id
412
-
413
- @classmethod
414
- def set_seed(cls, seed: int | None = None) -> None:
415
- """
416
- Set the seed for the random number generator.
417
-
418
- Args:
419
- seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
420
-
421
- Returns:
422
- None
423
- """
424
- if seed is None:
425
- seed = int(time.time() * 1000)
426
- log.critical(f"Seeding {cls.__name__} with seed {seed}")
427
- cls._rng = np.random.default_rng(seed)
428
-
429
451
  # endregion
430
452
 
431
453
  primary_metric: MetricConfig | None = None
@@ -695,8 +717,9 @@ class TrainerConfig(C.Config):
695
717
 
696
718
  auto_set_default_root_dir: bool = True
697
719
  """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
698
- save_checkpoint_metadata: bool = True
699
- """If enabled, will save additional metadata whenever a checkpoint is saved."""
720
+ save_checkpoint_metadata: Literal[True] = True
721
+ """Will save additional metadata whenever a checkpoint is saved.
722
+ This is a core feature of nshtrainer and cannot be disabled."""
700
723
  auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
701
724
  """If enabled, will automatically set the debug flag to True if:
702
725
  - The trainer is running in fast_dev_run mode.
@@ -755,40 +778,40 @@ class TrainerConfig(C.Config):
755
778
  None,
756
779
  )
757
780
 
758
- def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
759
- yield self.directory.setup_callback
760
- yield self.early_stopping
761
- yield self.checkpoint_saving
762
- yield self.lr_monitor
763
- yield from (
764
- logger_config
765
- for logger_config in self.enabled_loggers()
766
- if logger_config is not None
767
- and isinstance(logger_config, CallbackConfigBase)
768
- )
769
- yield self.log_epoch
770
- yield self.log_norms
771
- yield self.hf_hub
772
- yield self.shared_parameters
773
- yield self.reduce_lr_on_plateau_sanity_checking
774
- yield self.auto_set_debug_flag
775
- yield self.auto_validate_metrics
776
- yield from self.callbacks
781
+ # region Helper Methods
782
+ def id_(self, value: str):
783
+ """
784
+ Set the id for the trainer configuration in-place.
777
785
 
778
- def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
779
- # Disable all loggers if barebones mode is enabled
780
- if self.barebones:
781
- return
786
+ Parameters
787
+ ----------
788
+ value : str
789
+ The id value to set
782
790
 
783
- yield from self.enabled_loggers()
784
- yield self.actsave_logger
791
+ Returns
792
+ -------
793
+ self
794
+ Returns self for method chaining
795
+ """
796
+ self.id = value
797
+ return self
785
798
 
786
- def _nshtrainer_validate_before_run(self):
787
- # shared_parameters is not supported under barebones mode
788
- if self.barebones and self.shared_parameters:
789
- raise ValueError("shared_parameters is not supported under barebones mode")
799
+ def with_id(self, value: str):
800
+ """
801
+ Create a copy of the current configuration with an updated id.
802
+
803
+ Parameters
804
+ ----------
805
+ value : str
806
+ The id value to set
807
+
808
+ Returns
809
+ -------
810
+ TrainerConfig
811
+ A new instance of the configuration with the updated id
812
+ """
813
+ return copy.deepcopy(self).id_(value)
790
814
 
791
- # region Helper Methods
792
815
  def fast_dev_run_(self, value: int | bool = True, /):
793
816
  """
794
817
  Enables fast_dev_run mode for the trainer.
@@ -831,6 +854,349 @@ class TrainerConfig(C.Config):
831
854
  """
832
855
  return copy.deepcopy(self).project_root_(project_root)
833
856
 
857
+ def name_(self, *parts: str):
858
+ """
859
+ Set the name for the trainer configuration in-place.
860
+
861
+ Parameters
862
+ ----------
863
+ *parts : str
864
+ The parts of the name to set. Will be joined with spaces.
865
+
866
+ Returns
867
+ -------
868
+ self
869
+ Returns self for method chaining
870
+ """
871
+ self.name = list(parts)
872
+ return self
873
+
874
+ def with_name(self, *parts: str):
875
+ """
876
+ Create a copy of the current configuration with an updated name.
877
+
878
+ Parameters
879
+ ----------
880
+ *parts : str
881
+ The parts of the name to set. Will be joined with spaces.
882
+
883
+ Returns
884
+ -------
885
+ TrainerConfig
886
+ A new instance of the configuration with the updated name
887
+ """
888
+ return copy.deepcopy(self).name_(*parts)
889
+
890
+ def project_(self, project: str | None):
891
+ """
892
+ Set the project name for the trainer configuration in-place.
893
+
894
+ Parameters
895
+ ----------
896
+ project : str | None
897
+ The project name to set
898
+
899
+ Returns
900
+ -------
901
+ self
902
+ Returns self for method chaining
903
+ """
904
+ self.project = project
905
+ return self
906
+
907
+ def with_project(self, project: str | None):
908
+ """
909
+ Create a copy of the current configuration with an updated project name.
910
+
911
+ Parameters
912
+ ----------
913
+ project : str | None
914
+ The project name to set
915
+
916
+ Returns
917
+ -------
918
+ TrainerConfig
919
+ A new instance of the configuration with the updated project name
920
+ """
921
+ return copy.deepcopy(self).project_(project)
922
+
923
+ def tags_(self, *tags: str):
924
+ """
925
+ Set the tags for the trainer configuration in-place.
926
+
927
+ Parameters
928
+ ----------
929
+ *tags : str
930
+ The tags to set
931
+
932
+ Returns
933
+ -------
934
+ self
935
+ Returns self for method chaining
936
+ """
937
+ self.tags = list(tags)
938
+ return self
939
+
940
+ def with_tags(self, *tags: str):
941
+ """
942
+ Create a copy of the current configuration with updated tags.
943
+
944
+ Parameters
945
+ ----------
946
+ *tags : str
947
+ The tags to set
948
+
949
+ Returns
950
+ -------
951
+ TrainerConfig
952
+ A new instance of the configuration with the updated tags
953
+ """
954
+ return copy.deepcopy(self).tags_(*tags)
955
+
956
+ def add_tags_(self, *tags: str):
957
+ """
958
+ Add tags to the trainer configuration in-place.
959
+
960
+ Parameters
961
+ ----------
962
+ *tags : str
963
+ The tags to add
964
+
965
+ Returns
966
+ -------
967
+ self
968
+ Returns self for method chaining
969
+ """
970
+ self.tags.extend(tags)
971
+ return self
972
+
973
+ def with_added_tags(self, *tags: str):
974
+ """
975
+ Create a copy of the current configuration with additional tags.
976
+
977
+ Parameters
978
+ ----------
979
+ *tags : str
980
+ The tags to add
981
+
982
+ Returns
983
+ -------
984
+ TrainerConfig
985
+ A new instance of the configuration with the additional tags
986
+ """
987
+ return copy.deepcopy(self).add_tags_(*tags)
988
+
989
+ def notes_(self, *notes: str):
990
+ """
991
+ Set the notes for the trainer configuration in-place.
992
+
993
+ Parameters
994
+ ----------
995
+ *notes : str
996
+ The notes to set
997
+
998
+ Returns
999
+ -------
1000
+ self
1001
+ Returns self for method chaining
1002
+ """
1003
+ self.notes = list(notes)
1004
+ return self
1005
+
1006
+ def with_notes(self, *notes: str):
1007
+ """
1008
+ Create a copy of the current configuration with updated notes.
1009
+
1010
+ Parameters
1011
+ ----------
1012
+ *notes : str
1013
+ The notes to set
1014
+
1015
+ Returns
1016
+ -------
1017
+ TrainerConfig
1018
+ A new instance of the configuration with the updated notes
1019
+ """
1020
+ return copy.deepcopy(self).notes_(*notes)
1021
+
1022
+ def add_notes_(self, *notes: str):
1023
+ """
1024
+ Add notes to the trainer configuration in-place.
1025
+
1026
+ Parameters
1027
+ ----------
1028
+ *notes : str
1029
+ The notes to add
1030
+
1031
+ Returns
1032
+ -------
1033
+ self
1034
+ Returns self for method chaining
1035
+ """
1036
+ self.notes.extend(notes)
1037
+ return self
1038
+
1039
+ def with_added_notes(self, *notes: str):
1040
+ """
1041
+ Create a copy of the current configuration with additional notes.
1042
+
1043
+ Parameters
1044
+ ----------
1045
+ *notes : str
1046
+ The notes to add
1047
+
1048
+ Returns
1049
+ -------
1050
+ TrainerConfig
1051
+ A new instance of the configuration with the additional notes
1052
+ """
1053
+ return copy.deepcopy(self).add_notes_(*notes)
1054
+
1055
+ def meta_(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
1056
+ """
1057
+ Update the `meta` dictionary in-place with the provided key-value pairs.
1058
+
1059
+ This method allows updating the meta information associated with the trainer
1060
+ configuration by either passing a dictionary or keyword arguments.
1061
+
1062
+ Parameters
1063
+ ----------
1064
+ meta : dict[str, Any] | None, optional
1065
+ A dictionary containing meta information to be added, by default None
1066
+ **kwargs : Any
1067
+ Additional key-value pairs to be added to the meta dictionary
1068
+
1069
+ Returns
1070
+ -------
1071
+ self
1072
+ Returns self for method chaining
1073
+ """
1074
+ if meta is not None:
1075
+ self.meta.update(meta)
1076
+ self.meta.update(kwargs)
1077
+ return self
1078
+
1079
+ def with_meta(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
1080
+ """
1081
+ Create a copy of the current configuration with updated meta information.
1082
+
1083
+ This method is similar to `meta_`, but it returns a new instance of the configuration
1084
+ with the updated meta information instead of modifying the current instance.
1085
+
1086
+ Parameters
1087
+ ----------
1088
+ meta : dict[str, Any] | None, optional
1089
+ A dictionary containing meta information to be added, by default None
1090
+ **kwargs : Any
1091
+ Additional key-value pairs to be added to the meta dictionary
1092
+
1093
+ Returns
1094
+ -------
1095
+ TrainerConfig
1096
+ A new instance of the configuration with updated meta information
1097
+ """
1098
+
1099
+ return self.model_copy(deep=True).meta_(meta, **kwargs)
1100
+
1101
+ def debug_(self, value: bool = True):
1102
+ """
1103
+ Set the debug flag for the trainer configuration in-place.
1104
+
1105
+ Parameters
1106
+ ----------
1107
+ value : bool, optional
1108
+ The debug flag value to set, by default True
1109
+
1110
+ Returns
1111
+ -------
1112
+ self
1113
+ Returns self for method chaining
1114
+ """
1115
+ self.debug = value
1116
+ return self
1117
+
1118
+ def with_debug(self, value: bool = True):
1119
+ """
1120
+ Create a copy of the current configuration with an updated debug flag.
1121
+
1122
+ Parameters
1123
+ ----------
1124
+ value : bool, optional
1125
+ The debug flag value to set, by default True
1126
+
1127
+ Returns
1128
+ -------
1129
+ TrainerConfig
1130
+ A new instance of the configuration with the updated debug flag
1131
+ """
1132
+ return copy.deepcopy(self).debug_(value)
1133
+
1134
+ def ckpt_path_(self, path: Literal["none"] | str | Path | None):
1135
+ """
1136
+ Set the checkpoint path for the trainer configuration in-place.
1137
+
1138
+ Parameters
1139
+ ----------
1140
+ path : Literal["none"] | str | Path | None
1141
+ The checkpoint path to set
1142
+
1143
+ Returns
1144
+ -------
1145
+ self
1146
+ Returns self for method chaining
1147
+ """
1148
+ self.ckpt_path = path
1149
+ return self
1150
+
1151
+ def with_ckpt_path(self, path: Literal["none"] | str | Path | None):
1152
+ """
1153
+ Create a copy of the current configuration with an updated checkpoint path.
1154
+
1155
+ Parameters
1156
+ ----------
1157
+ path : Literal["none"] | str | Path | None
1158
+ The checkpoint path to set
1159
+
1160
+ Returns
1161
+ -------
1162
+ TrainerConfig
1163
+ A new instance of the configuration with the updated checkpoint path
1164
+ """
1165
+ return copy.deepcopy(self).ckpt_path_(path)
1166
+
1167
+ def barebones_(self, value: bool = True):
1168
+ """
1169
+ Set the barebones flag for the trainer configuration in-place.
1170
+
1171
+ Parameters
1172
+ ----------
1173
+ value : bool, optional
1174
+ The barebones flag value to set, by default True
1175
+
1176
+ Returns
1177
+ -------
1178
+ self
1179
+ Returns self for method chaining
1180
+ """
1181
+ self.barebones = value
1182
+ return self
1183
+
1184
+ def with_barebones(self, value: bool = True):
1185
+ """
1186
+ Create a copy of the current configuration with an updated barebones flag.
1187
+
1188
+ Parameters
1189
+ ----------
1190
+ value : bool, optional
1191
+ The barebones flag value to set, by default True
1192
+
1193
+ Returns
1194
+ -------
1195
+ TrainerConfig
1196
+ A new instance of the configuration with the updated barebones flag
1197
+ """
1198
+ return copy.deepcopy(self).barebones_(value)
1199
+
834
1200
  def reset_run(
835
1201
  self,
836
1202
  *,
@@ -873,3 +1239,89 @@ class TrainerConfig(C.Config):
873
1239
  return config
874
1240
 
875
1241
  # endregion
1242
+
1243
+ # region Random ID Generation
1244
+ _rng: ClassVar[np.random.Generator | None] = None
1245
+
1246
+ @classmethod
1247
+ def generate_id(cls, *, length: int = 8) -> str:
1248
+ """
1249
+ Generate a random ID of specified length.
1250
+
1251
+ """
1252
+ if (rng := cls._rng) is None:
1253
+ rng = np.random.default_rng()
1254
+
1255
+ alphabet = list(string.ascii_lowercase + string.digits)
1256
+
1257
+ id = "".join(rng.choice(alphabet) for _ in range(length))
1258
+ return id
1259
+
1260
+ @classmethod
1261
+ def set_seed(cls, seed: int | None = None) -> None:
1262
+ """
1263
+ Set the seed for the random number generator.
1264
+
1265
+ Args:
1266
+ seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
1267
+
1268
+ Returns:
1269
+ None
1270
+ """
1271
+ if seed is None:
1272
+ seed = int(time.time() * 1000)
1273
+ log.critical(f"Seeding {cls.__name__} with seed {seed}")
1274
+ cls._rng = np.random.default_rng(seed)
1275
+
1276
+ # endregion
1277
+
1278
+ # region Internal Methods
1279
+ def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
1280
+ yield self.directory.setup_callback
1281
+ yield self.early_stopping
1282
+ yield self.checkpoint_saving
1283
+ yield self.lr_monitor
1284
+ yield from (
1285
+ logger_config
1286
+ for logger_config in self.enabled_loggers()
1287
+ if logger_config is not None
1288
+ and isinstance(logger_config, CallbackConfigBase)
1289
+ )
1290
+ yield self.log_epoch
1291
+ yield self.log_norms
1292
+ yield self.hf_hub
1293
+ yield self.shared_parameters
1294
+ yield self.reduce_lr_on_plateau_sanity_checking
1295
+ yield self.auto_set_debug_flag
1296
+ yield self.auto_validate_metrics
1297
+ yield from self.callbacks
1298
+
1299
+ def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
1300
+ # Disable all loggers if barebones mode is enabled
1301
+ if self.barebones:
1302
+ return
1303
+
1304
+ yield from self.enabled_loggers()
1305
+ yield self.actsave_logger
1306
+
1307
+ def _nshtrainer_validate_before_run(self):
1308
+ # shared_parameters is not supported under barebones mode
1309
+ if self.barebones and self.shared_parameters:
1310
+ raise ValueError("shared_parameters is not supported under barebones mode")
1311
+
1312
+ if not self.save_checkpoint_metadata:
1313
+ raise ValueError(
1314
+ "save_checkpoint_metadata must be True. This is a core feature of nshtrainer and cannot be disabled."
1315
+ )
1316
+
1317
+ def _nshtrainer_set_id_if_missing(self):
1318
+ """
1319
+ Set the ID for the configuration object if it is missing.
1320
+ """
1321
+ if self.id is C.MISSING:
1322
+ self.id = self.generate_id()
1323
+ log.info(f"TrainerConfig's run ID is missing, setting to {self.id}.")
1324
+ else:
1325
+ log.debug(f"TrainerConfig's run ID is already set to {self.id}.")
1326
+
1327
+ # endregion
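
Editor's note: taken together, the new fluent helpers support chained configuration. A short usage sketch (assuming a TrainerConfig can be constructed with defaults, which may require additional fields in practice):

    from nshtrainer.trainer._config import TrainerConfig

    config = (
        TrainerConfig()
        .name_("water", "md17")   # name parts are joined with spaces
        .project_("ablations")
        .add_tags_("baseline")
        .debug_()                 # in-place setters return self for chaining
    )

    # with_* variants deep-copy instead of mutating
    fast_config = config.with_name("water", "md17", "fast").with_debug(False)

    # The run id stays MISSING until the trainer calls _nshtrainer_set_id_if_missing(),
    # or it can be set explicitly:
    config = config.id_(TrainerConfig.generate_id())
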
nshtrainer/trainer/trainer.py CHANGED
@@ -45,6 +45,9 @@ patch_log_hparams_function()
45
45
 
46
46
 
47
47
  class Trainer(LightningTrainer):
48
+ profiler: Profiler
49
+ """Profiler used for profiling the training process."""
50
+
48
51
  CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"
49
52
 
50
53
  @property
@@ -316,6 +319,7 @@ class Trainer(LightningTrainer):
316
319
  f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
317
320
  f"Got {type(hparams)=} instead."
318
321
  )
322
+ hparams._nshtrainer_set_id_if_missing()
319
323
  hparams = hparams.model_deep_validate()
320
324
  hparams._nshtrainer_validate_before_run()
321
325
 
@@ -468,6 +472,11 @@ class Trainer(LightningTrainer):
468
472
  weights_only: bool = False,
469
473
  storage_options: Any | None = None,
470
474
  ):
475
+ assert self.hparams.save_checkpoint_metadata, (
476
+ "Checkpoint metadata is not enabled. "
477
+ "Please set `hparams.save_checkpoint_metadata=True`."
478
+ )
479
+
471
480
  filepath = Path(filepath)
472
481
 
473
482
  if self.model is None:
@@ -475,7 +484,7 @@ class Trainer(LightningTrainer):
475
484
  "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
476
485
  " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
477
486
  )
478
- with self.profiler.profile("save_checkpoint"): # type: ignore
487
+ with self.profiler.profile("save_checkpoint"):
479
488
  checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
480
489
  # Update the checkpoint for the trainer hyperparameters
481
490
  checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY] = self.hparams.model_dump(
@@ -488,7 +497,7 @@ class Trainer(LightningTrainer):
488
497
 
489
498
  # Save the checkpoint metadata
490
499
  metadata_path = None
491
- if self.hparams.save_checkpoint_metadata and self.is_global_zero:
500
+ if self.is_global_zero:
492
501
  # Generate the metadata and write to disk
493
502
  metadata_path = write_checkpoint_metadata(self, filepath)
494
503
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.3.5
3
+ Version: 1.4.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,16 +1,15 @@
1
1
  nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
2
- nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
2
+ nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
3
3
  nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
4
- nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
4
+ nshtrainer/_checkpoint/metadata.py,sha256=El9Ip8jGA7mAN5rAMpVfg1dfUe2dGoOOfvF1JfYJGHM,5676
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
6
- nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
7
6
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
- nshtrainer/_hf_hub.py,sha256=kfN0wDxK5JWKKGZnX_706i0KXGhaS19p581LDTPxlRE,13996
7
+ nshtrainer/_hf_hub.py,sha256=OB4252GJ6AbKNCRmHVvEglvjYVMUN822BFYECABxfZU,14037
9
8
  nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
10
9
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
11
10
  nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
12
11
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=f7lpk8W4xqxk3PolBEU3AWt9VTIpoLW7wMUhC5DNm3c,6345
12
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=BjgfCXsf4Ihf1MNKkHBUwjHMLwc04PZO-2Bx-LdAazg,11010
14
13
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
15
14
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
16
15
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -23,7 +22,7 @@ nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH
23
22
  nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
24
23
  nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
25
24
  nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
26
- nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
25
+ nshtrainer/callbacks/lr_monitor.py,sha256=v45ehnwNO987087HfiOY5aIrVRbwdKMgPYRFHs1fyEE,1444
27
26
  nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
28
27
  nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
29
28
  nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
@@ -33,10 +32,9 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
33
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
34
33
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
35
34
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
36
- nshtrainer/configs/__init__.py,sha256=KD3uClMwnA4LfQ7rY5phDdUbp3j8NoZfaGbGPbpaJVs,15848
35
+ nshtrainer/configs/__init__.py,sha256=-yJ5Uk9VkANqfk-QnX2aynL0jSf7cJQuQNzT1GAE1x8,15684
37
36
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
38
37
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
39
- nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
40
38
  nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
41
39
  nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
42
40
  nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
@@ -85,8 +83,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
85
83
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
86
84
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
87
85
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
88
- nshtrainer/configs/trainer/__init__.py,sha256=YLlDOUYDp_qURHhcmhCxTcY6K5AbmoTxdzBPB9SEZII,8040
89
- nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
86
+ nshtrainer/configs/trainer/__init__.py,sha256=DM2PlB4WRDZ_dqEeW91LbKRFa4sIF_pETU0T9GYJ5-g,8073
87
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=z5UpuXktBanLOYNkkbgbbHE06iQtcSuAKTpnx2TLmCo,3850
90
88
  nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
91
89
  nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
92
90
  nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -135,7 +133,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
135
133
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
136
134
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
137
135
  nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
138
- nshtrainer/trainer/_config.py,sha256=SohR7uxANnP3xrrcW_mAjk6TuDamsW5Qdk3dlnPinDw,33457
136
+ nshtrainer/trainer/_config.py,sha256=FWEspBYt_bjLhUSkJApkC9pfYBTlFBHmIQRFNGpGjAc,45849
139
137
  nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
140
138
  nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
141
139
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -148,7 +146,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
148
146
  nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
149
147
  nshtrainer/trainer/signal_connector.py,sha256=ZgbSkbthoe8MYN6rBoFf-7UDpQtc9fs9pG_FNvTYSfs,10962
150
148
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
151
- nshtrainer/trainer/trainer.py,sha256=6oky6E8cjGqUNzJGyyTO551pE9A6YueOv5oxg1fZVR0,24129
149
+ nshtrainer/trainer/trainer.py,sha256=G_tHqzZCHJazhROcoKeOI5rZ5A8F8XlghiIWkdMbPR0,24387
152
150
  nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
153
151
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
154
152
  nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
@@ -161,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
161
159
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
162
160
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
163
161
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
164
- nshtrainer-1.3.5.dist-info/METADATA,sha256=GUU8QgA8rxeCX1Z9FfwSvZQ46f0xsMvtm4p1Uz8uEwE,979
165
- nshtrainer-1.3.5.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
166
- nshtrainer-1.3.5.dist-info/RECORD,,
162
+ nshtrainer-1.4.0.dist-info/METADATA,sha256=PIV_5Swp1HhgFU2ZBj_X1tCeOBfNhrhTXOFB1vgunno,979
163
+ nshtrainer-1.4.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
+ nshtrainer-1.4.0.dist-info/RECORD,,
nshtrainer/_directory.py DELETED
@@ -1,72 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import logging
4
- from pathlib import Path
5
-
6
- import nshconfig as C
7
-
8
- from .callbacks.directory_setup import DirectorySetupCallbackConfig
9
- from .loggers import LoggerConfig
10
-
11
- log = logging.getLogger(__name__)
12
-
13
-
14
- class DirectoryConfig(C.Config):
15
- project_root: Path | None = None
16
- """
17
- Root directory for this project.
18
-
19
- This isn't specific to the run; it is the parent directory of all runs.
20
- """
21
-
22
- logdir_basename: str = "nshtrainer"
23
- """Base name for the log directory."""
24
-
25
- setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
26
- """Configuration for the directory setup PyTorch Lightning callback."""
27
-
28
- def resolve_run_root_directory(self, run_id: str) -> Path:
29
- if (project_root_dir := self.project_root) is None:
30
- project_root_dir = Path.cwd()
31
-
32
- # The default base dir is $CWD/{logdir_basename}/{id}/
33
- base_dir = project_root_dir / self.logdir_basename
34
- base_dir.mkdir(exist_ok=True)
35
-
36
- # Add a .gitignore file to the {logdir_basename} directory
37
- # which will ignore all files except for the .gitignore file itself
38
- gitignore_path = base_dir / ".gitignore"
39
- if not gitignore_path.exists():
40
- gitignore_path.touch()
41
- gitignore_path.write_text("*\n")
42
-
43
- base_dir = base_dir / run_id
44
- base_dir.mkdir(exist_ok=True)
45
-
46
- return base_dir
47
-
48
- def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
49
- # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
50
- if (subdir := getattr(self, subdirectory, None)) is not None:
51
- assert isinstance(subdir, Path), (
52
- f"Expected a Path for {subdirectory}, got {type(subdir)}"
53
- )
54
- return subdir
55
-
56
- dir = self.resolve_run_root_directory(run_id)
57
- dir = dir / subdirectory
58
- dir.mkdir(exist_ok=True)
59
- return dir
60
-
61
- def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
62
- if (log_dir := logger.log_dir) is not None:
63
- return log_dir
64
-
65
- # Save to {logdir_basename}/{id}/log/{logger name}
66
- log_dir = self.resolve_subdirectory(run_id, "log")
67
- log_dir = log_dir / logger.resolve_logger_dirname()
68
- # ^ NOTE: Logger must have a `name` attribute, as this is
69
- # the discriminator for the logger registry
70
- log_dir.mkdir(exist_ok=True)
71
-
72
- return log_dir
nshtrainer/configs/_directory/__init__.py DELETED
@@ -1,15 +0,0 @@
1
- from __future__ import annotations
2
-
3
- __codegen__ = True
4
-
5
- from nshtrainer._directory import DirectoryConfig as DirectoryConfig
6
- from nshtrainer._directory import (
7
- DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
8
- )
9
- from nshtrainer._directory import LoggerConfig as LoggerConfig
10
-
11
- __all__ = [
12
- "DirectoryConfig",
13
- "DirectorySetupCallbackConfig",
14
- "LoggerConfig",
15
- ]