nshtrainer 0.19.3__tar.gz → 0.20.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/_checkpoint/loader.py +12 -4
  4. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/config.py +7 -4
  5. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/trainer/checkpoint_connector.py +8 -2
  6. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/README.md +0 -0
  7. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/__init__.py +0 -0
  8. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  9. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  10. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  11. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/_hf_hub.py +0 -0
  12. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  13. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  14. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  15. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/base.py +0 -0
  16. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  17. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  18. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  19. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  20. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  21. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  22. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/ema.py +0 -0
  23. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  24. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  25. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/interval.py +0 -0
  26. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  27. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  28. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/data/transform.py +0 -0
  35. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/__init__.py +0 -0
  36. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/_experimental.py +0 -0
  37. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/actsave.py +0 -0
  38. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/callbacks.py +0 -0
  39. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/config.py +0 -0
  40. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/data.py +0 -0
  41. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/log.py +0 -0
  42. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  43. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/model.py +0 -0
  44. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/nn.py +0 -0
  45. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/optimizer.py +0 -0
  46. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/runner.py +0 -0
  47. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/snapshot.py +0 -0
  48. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/snoop.py +0 -0
  49. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/trainer.py +0 -0
  50. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/typecheck.py +0 -0
  51. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/ll/util.py +0 -0
  52. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/loggers/__init__.py +0 -0
  53. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/loggers/_base.py +0 -0
  54. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/loggers/csv.py +0 -0
  55. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  56. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/loggers/wandb.py +0 -0
  57. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  58. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  59. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  60. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  61. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/metrics/__init__.py +0 -0
  62. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/metrics/_config.py +0 -0
  63. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/__init__.py +0 -0
  64. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/base.py +0 -0
  65. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/callback.py +0 -0
  66. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/debug.py +0 -0
  67. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/distributed.py +0 -0
  68. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/logger.py +0 -0
  69. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/profiler.py +0 -0
  70. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  71. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  72. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/nn/__init__.py +0 -0
  73. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/nn/mlp.py +0 -0
  74. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/nn/module_dict.py +0 -0
  75. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/nn/module_list.py +0 -0
  76. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  77. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/optimizer.py +0 -0
  78. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/runner.py +0 -0
  79. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  80. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/trainer/__init__.py +0 -0
  81. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  82. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  83. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/trainer/trainer.py +0 -0
  84. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/_environment_info.py +0 -0
  85. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/_useful_types.py +0 -0
  86. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/environment.py +0 -0
  87. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/seed.py +0 -0
  88. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/slurm.py +0 -0
  89. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/typed.py +0 -0
  90. {nshtrainer-0.19.3 → nshtrainer-0.20.0}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.19.3
3
+ Version: 0.20.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.19.3"
3
+ version = "0.20.0"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -76,7 +76,11 @@ class CheckpointLoadingConfig(C.Config):
76
76
  """Whether to include checkpoints from HPC pre-emption."""
77
77
 
78
78
  @classmethod
79
- def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
79
+ def none(cls, include_hpc: bool = False):
80
+ return cls(strategies=[], include_hpc=include_hpc)
81
+
82
+ @classmethod
83
+ def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
80
84
  if ckpt is None:
81
85
  ckpt = "last"
82
86
  match ckpt:
@@ -90,6 +94,8 @@ class CheckpointLoadingConfig(C.Config):
90
94
  strategies=[LastCheckpointStrategyConfig()],
91
95
  include_hpc=True,
92
96
  )
97
+ case "none":
98
+ return cls.none()
93
99
  case Path() | str():
94
100
  ckpt = Path(ckpt)
95
101
  return cls(
@@ -103,7 +109,7 @@ class CheckpointLoadingConfig(C.Config):
103
109
  assert_never(ckpt)
104
110
 
105
111
  @classmethod
106
- def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
112
+ def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
107
113
  if ckpt is None:
108
114
  log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
109
115
  ckpt = "last"
@@ -119,6 +125,8 @@ class CheckpointLoadingConfig(C.Config):
119
125
  strategies=[LastCheckpointStrategyConfig()],
120
126
  include_hpc=False,
121
127
  )
128
+ case "none":
129
+ return cls.none(include_hpc=False)
122
130
  case Path() | str():
123
131
  ckpt = Path(ckpt)
124
132
  return cls(
@@ -131,7 +139,7 @@ class CheckpointLoadingConfig(C.Config):
131
139
  @classmethod
132
140
  def auto(
133
141
  cls,
134
- ckpt: Literal["best", "last"] | str | Path | None,
142
+ ckpt: Literal["best", "last", "none"] | str | Path | None,
135
143
  trainer_mode: TrainerFn,
136
144
  ):
137
145
  """
@@ -142,7 +150,7 @@ class CheckpointLoadingConfig(C.Config):
142
150
 
143
151
  Parameters:
144
152
  -----------
145
- ckpt : Literal["best", "last"] | str | Path | None
153
+ ckpt : Literal["best", "last", "none"] | str | Path | None
146
154
  Specifies the checkpoint loading preference:
147
155
  - "best": Use the best checkpoint based on the primary metric.
148
156
  - "last": Use the most recent checkpoint.
@@ -811,11 +811,14 @@ class SanityCheckingConfig(C.Config):
811
811
 
812
812
 
813
813
  class TrainerConfig(C.Config):
814
- ckpt_path: str | Path | None = None
815
- """Path to a checkpoint to load and resume training from."""
814
+ ckpt_path: Literal["none"] | str | Path | None = None
815
+ """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
816
816
 
817
- checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
818
- """Checkpoint loading configuration options."""
817
+ checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
818
+ """Checkpoint loading configuration options.
819
+ `"auto"` will automatically determine the best checkpoint loading strategy based on the provided `ckpt_path` and trainer mode.
820
+ `"none"` will disable checkpoint loading.
821
+ """
819
822
 
820
823
  checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
821
824
  """Checkpoint saving configuration options."""
@@ -31,8 +31,14 @@ class _CheckpointConnector(_LightningCheckpointConnector):
31
31
 
32
32
  # Now, resolve the checkpoint loader config.
33
33
  root_config = cast("BaseConfig", trainer._base_module.config)
34
- if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
35
- ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
34
+ ckpt_loader_config = root_config.trainer.checkpoint_loading
35
+ match ckpt_loader_config:
36
+ case "auto":
37
+ ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
38
+ case "none":
39
+ ckpt_loader_config = CheckpointLoadingConfig.none()
40
+ case _:
41
+ pass
36
42
  log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
37
43
 
38
44
  # Use the config to resolve the checkpoint.
File without changes