nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (124)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/trainer/_config.py

@@ -1,6 +1,10 @@
 from __future__ import annotations

+import copy
 import logging
+import os
+import string
+import time
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
@@ -8,6 +12,7 @@ from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
+    ClassVar,
     Literal,
     Protocol,
     TypeAlias,
@@ -15,6 +20,7 @@ from typing import (
 )

 import nshconfig as C
+import numpy as np
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
 from lightning.pytorch.accelerators import Accelerator
@@ -28,6 +34,7 @@ from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypedDict, TypeVar, override

 from .._checkpoint.loader import CheckpointLoadingConfig
+from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
@@ -47,10 +54,10 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers.actsave import ActSaveLoggerConfig
+from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
-
-if TYPE_CHECKING:
-    from ..model.config import BaseConfig
+from ..util._environment_info import EnvironmentConfig

 log = logging.getLogger(__name__)

@@ -71,7 +78,7 @@ class LoggingConfig(CallbackConfigBase):
     log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
     """If enabled, will log the fractional epoch number to the logger."""

-    actsave_logged_metrics: bool = False
+    actsave_logger: ActSaveLoggerConfig | None = None
     """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""

     @property
@@ -103,12 +110,12 @@ class LoggingConfig(CallbackConfigBase):
             None,
         )

-    def create_loggers(self, root_config: "BaseConfig"):
+    def create_loggers(self, trainer_config: TrainerConfig):
         """
         Constructs and returns a list of loggers based on the provided root configuration.

         Args:
-            root_config (BaseConfig): The root configuration object.
+            trainer_config (TrainerConfig): The root configuration object.

         Returns:
             list[Logger]: A list of constructed loggers.
@@ -123,12 +130,16 @@ class LoggingConfig(CallbackConfigBase):
         ):
             if not logger_config.enabled:
                 continue
-            if (logger := logger_config.create_logger(root_config)) is None:
+            if (logger := logger_config.create_logger(trainer_config)) is None:
                 continue
             yield logger

+        # If the actsave_metrics is enabled, add the ActSave logger
+        if self.actsave_logger:
+            yield self.actsave_logger.create_logger(trainer_config)
+
     @override
-    def create_callbacks(self, root_config):
+    def create_callbacks(self, trainer_config):
         if self.log_lr:
             from lightning.pytorch.callbacks import LearningRateMonitor

@@ -139,13 +150,13 @@ class LoggingConfig(CallbackConfigBase):
             yield LearningRateMonitor(logging_interval=logging_interval)

         if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(root_config)
+            yield from self.log_epoch.create_callbacks(trainer_config)

         for logger in self.loggers:
             if not logger or not isinstance(logger, CallbackConfigBase):
                 continue

-            yield from logger.create_callbacks(root_config)
+            yield from logger.create_callbacks(trainer_config)


 class GradientClippingConfig(C.Config):
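
Usage note: the old `actsave_logged_metrics: bool` flag on `LoggingConfig` is replaced by an `actsave_logger` sub-config, which `create_loggers` turns into a full logger. A minimal sketch of opting in; the no-argument `ActSaveLoggerConfig()` construction is an assumption, not shown in this diff:

    # Hypothetical usage sketch; only `LoggingConfig`, `actsave_logger`, and
    # `ActSaveLoggerConfig` are taken from the diff above.
    from nshtrainer.loggers.actsave import ActSaveLoggerConfig
    from nshtrainer.trainer._config import LoggingConfig

    # Enable the ActSave logger alongside the regular loggers.
    logging_config = LoggingConfig(actsave_logger=ActSaveLoggerConfig())
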
@@ -172,7 +183,7 @@ class OptimizationConfig(CallbackConfigBase):
     """Gradient clipping configuration, or None to disable gradient clipping."""

     @override
-    def create_callbacks(self, root_config):
+    def create_callbacks(self, trainer_config):
         from ..callbacks.norm_logging import NormLoggingCallbackConfig

         yield from NormLoggingCallbackConfig(
@@ -180,7 +191,7 @@ class OptimizationConfig(CallbackConfigBase):
             log_grad_norm_per_param=self.log_grad_norm_per_param,
             log_param_norm=self.log_param_norm,
             log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(root_config)
+        ).create_callbacks(trainer_config)


 TPlugin = TypeVar(
@@ -274,22 +285,22 @@ class CheckpointSavingConfig(CallbackConfigBase):
         self.enabled = False
         return self

-    def should_save_checkpoints(self, root_config: "BaseConfig"):
+    def should_save_checkpoints(self, trainer_config: TrainerConfig):
         if not self.enabled:
             return False

-        if root_config.trainer.fast_dev_run:
+        if trainer_config.fast_dev_run:
             return False

         return True

     @override
-    def create_callbacks(self, root_config: "BaseConfig"):
-        if not self.should_save_checkpoints(root_config):
+    def create_callbacks(self, trainer_config: TrainerConfig):
+        if not self.should_save_checkpoints(trainer_config):
             return

         for callback_config in self.checkpoint_callbacks:
-            yield from callback_config.create_callbacks(root_config)
+            yield from callback_config.create_callbacks(trainer_config)


 class LightningTrainerKwargs(TypedDict, total=False):
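
The signature change from `create_callbacks(root_config: BaseConfig)` to `create_callbacks(trainer_config: TrainerConfig)` ripples through every callback config: trainer-level fields are now read directly (e.g. `trainer_config.fast_dev_run` instead of `root_config.trainer.fast_dev_run`). A rough sketch of what a downstream config looks like after this change; `MyCallbackConfig` is hypothetical, and the `nshtrainer.callbacks.base` import path for `CallbackConfigBase` is assumed rather than shown in this diff:

    from typing import TYPE_CHECKING

    from typing_extensions import override

    from nshtrainer.callbacks.base import CallbackConfigBase  # assumed import path

    if TYPE_CHECKING:
        from nshtrainer.trainer._config import TrainerConfig


    class MyCallbackConfig(CallbackConfigBase):  # hypothetical subclass
        @override
        def create_callbacks(self, trainer_config: "TrainerConfig"):
            # 0.44.x received the root BaseConfig; 1.0 passes the TrainerConfig directly,
            # so trainer fields are read without the `.trainer` hop.
            if trainer_config.fast_dev_run:
                return

            from lightning.pytorch.callbacks import LearningRateMonitor

            yield LearningRateMonitor(logging_interval="step")
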
@@ -541,6 +552,74 @@ class SanityCheckingConfig(C.Config):


 class TrainerConfig(C.Config):
+    # region Active Run Configuration
+    id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
+    """ID of the run."""
+    name: list[str] = []
+    """Run name in parts. Full name is constructed by joining the parts with spaces."""
+    project: str | None = None
+    """Project name."""
+    tags: list[str] = []
+    """Tags for the run."""
+    notes: list[str] = []
+    """Human readable notes for the run."""
+
+    @property
+    def full_name(self):
+        return " ".join(self.name)
+
+    debug: bool = False
+    """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
+
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
+        EnvironmentConfig.empty()
+    )
+    """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
+
+    directory: DirectoryConfig = DirectoryConfig()
+    """Directory configuration options."""
+
+    _rng: ClassVar[np.random.Generator | None] = None
+
+    @classmethod
+    def generate_id(cls, *, length: int = 8) -> str:
+        """
+        Generate a random ID of specified length.
+
+        """
+        if (rng := cls._rng) is None:
+            rng = np.random.default_rng()
+
+        alphabet = list(string.ascii_lowercase + string.digits)
+
+        id = "".join(rng.choice(alphabet) for _ in range(length))
+        return id
+
+    @classmethod
+    def set_seed(cls, seed: int | None = None) -> None:
+        """
+        Set the seed for the random number generator.
+
+        Args:
+            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
+
+        Returns:
+            None
+        """
+        if seed is None:
+            seed = int(time.time() * 1000)
+        log.critical(f"Seeding {cls.__name__} with seed {seed}")
+        cls._rng = np.random.default_rng(seed)
+
+    # endregion
+
+    primary_metric: MetricConfig | None = None
+    """Primary metric configuration options. This is used in the following ways:
+    - To determine the best model checkpoint to save with the ModelCheckpoint callback.
+    - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
+    - For the ReduceLROnPlateau scheduler.
+    """
+
     ckpt_path: Literal["none"] | str | Path | None = None
     """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""

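
The run-identity fields above now live on `TrainerConfig` itself (they appear to have previously lived on the removed root `BaseConfig`; see `nshtrainer/model/config.py` in the file list). A short usage sketch, assuming the remaining `TrainerConfig` fields all have defaults so a bare constructor call works:

    from nshtrainer.trainer._config import TrainerConfig

    # Seed the class-level RNG so generated run IDs are reproducible.
    TrainerConfig.set_seed(42)

    config = TrainerConfig()                       # `id` is auto-generated via generate_id()
    run_id = TrainerConfig.generate_id(length=12)  # explicit call with a custom length
    print(config.id, config.full_name or "<unnamed>", run_id)
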
@@ -788,3 +867,71 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
+
+    # region Helper Methods
+    def with_fast_dev_run(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        config = copy.deepcopy(self)
+        config.fast_dev_run = value
+        return config
+
+    def with_project_root(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        config = copy.deepcopy(self)
+        config.directory.project_root = Path(project_root)
+        return config
+
+    def reset_run(
+        self,
+        *,
+        id: bool = True,
+        basic: bool = True,
+        project_root: bool = True,
+        environment: bool = True,
+    ):
+        """
+        Reset the configuration object to its initial state.
+
+        Parameters:
+        - id (bool): If True, generate a new ID for the configuration object.
+        - basic (bool): If True, reset basic attributes like name, project, tags, and notes.
+        - project_root (bool): If True, reset the directory configuration to its initial state.
+        - environment (bool): If True, reset the environment configuration to its initial state.
+        - meta (bool): If True, reset the meta dictionary to an empty dictionary.
+
+        Returns:
+        - self: The updated configuration object.
+
+        """
+        config = copy.deepcopy(self)
+
+        if id:
+            config.id = config.generate_id()
+
+        if basic:
+            config.name = []
+            config.project = None
+            config.tags = []
+            config.notes = []
+
+        if project_root:
+            config.directory = DirectoryConfig()
+
+        if environment:
+            config.environment = EnvironmentConfig.empty()
+
+        return config
+
+    # endregion
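
These helpers are copy-on-write: each one deep-copies the config and returns the copy instead of mutating `self`, so they chain naturally. An illustrative sketch; `base_config` stands in for an already-built `TrainerConfig`, and the path is a placeholder:

    # `base_config` is assumed to be an existing, fully populated TrainerConfig.
    debug_config = base_config.with_fast_dev_run(2).with_project_root("/tmp/nshtrainer-runs")

    # `base_config` itself is left untouched; only the returned copies carry the overrides.
    fresh_config = debug_config.reset_run()  # new id; name/project/tags/notes cleared;
                                             # directory and environment reset to defaults
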
nshtrainer/trainer/checkpoint_connector.py

@@ -2,7 +2,6 @@ from __future__ import annotations

 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, cast

 from lightning.pytorch.trainer.connectors.checkpoint_connector import (
     _CheckpointConnector as _LightningCheckpointConnector,
@@ -12,8 +11,6 @@ from typing_extensions import override

 from .._checkpoint.loader import CheckpointLoadingConfig, _resolve_checkpoint

-if TYPE_CHECKING:
-    from ..model.config import BaseConfig
 log = logging.getLogger(__name__)


@@ -32,8 +29,7 @@ class _CheckpointConnector(_LightningCheckpointConnector):
             return None

         # Now, resolve the checkpoint loader config.
-        root_config = cast("BaseConfig", trainer._base_module.config)
-        ckpt_loader_config = root_config.trainer.checkpoint_loading
+        ckpt_loader_config = trainer.hparams.checkpoint_loading
         match ckpt_loader_config:
             case "auto":
                 ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
@@ -44,9 +40,7 @@ class _CheckpointConnector(_LightningCheckpointConnector):
         log.debug(f"Checkpoint loader config: {ckpt_loader_config}")

         # Use the config to resolve the checkpoint.
-        if (
-            ckpt_path := _resolve_checkpoint(ckpt_loader_config, root_config, trainer)
-        ) is None:
+        if (ckpt_path := _resolve_checkpoint(ckpt_loader_config, trainer)) is None:
             log.info(
                 "No checkpoint found for the current trainer state. "
                 "Training will start from scratch."
@@ -69,3 +63,24 @@ class _CheckpointConnector(_LightningCheckpointConnector):
         return super()._parse_ckpt_path(
             state_fn, ckpt_path, model_provided, model_connected
         )
+
+    @override
+    def dump_checkpoint(self, weights_only: bool = False):
+        checkpoint = super().dump_checkpoint(weights_only)
+
+        # Save the trainer's config.
+        _add_trainer_config_to_checkpoint_(checkpoint, self.trainer)
+
+        return checkpoint
+
+
+def _add_trainer_config_to_checkpoint_(checkpoint: dict, trainer):
+    from .trainer import Trainer
+
+    # If this isn't an `nshtrainer` trainer (which I don't know why it wouldn't be),
+    # then we just return.
+    if isinstance(trainer, Trainer):
+        return None
+
+    # Save the trainer's config.
+    checkpoint[trainer.CHECKPOINT_HYPER_PARAMS_KEY] = dict(trainer.hparams)
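
With `dump_checkpoint` now embedding the trainer's hparams in the checkpoint dictionary, the run's trainer config can be inspected from a saved checkpoint file without rebuilding the trainer. A minimal sketch of reading that entry back; it assumes `CHECKPOINT_HYPER_PARAMS_KEY` is readable as a class attribute on nshtrainer's `Trainer`, and the checkpoint path is a placeholder:

    import torch

    from nshtrainer.trainer.trainer import Trainer

    # weights_only=False allows non-tensor objects (e.g. config dicts) to be unpickled.
    ckpt = torch.load("path/to/last.ckpt", map_location="cpu", weights_only=False)

    # The exact key string is defined on the Trainer class, not in this diff.
    trainer_hparams = ckpt.get(Trainer.CHECKPOINT_HYPER_PARAMS_KEY)
    print(trainer_hparams)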