nshtrainer 0.15.1__tar.gz → 0.16.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/_checkpoint/loader.py +7 -5
  4. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/_checkpoint/metadata.py +7 -11
  5. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/checkpoint/_base.py +1 -0
  6. nshtrainer-0.16.1/src/nshtrainer/ll/snapshot.py +1 -0
  7. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/trainer/trainer.py +2 -1
  8. nshtrainer-0.15.1/src/nshtrainer/ll/snapshot.py +0 -1
  9. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/README.md +0 -0
  10. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/__init__.py +0 -0
  11. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  12. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  13. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  14. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  15. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  16. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/base.py +0 -0
  17. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  18. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  19. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  20. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  21. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  22. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/ema.py +0 -0
  23. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  24. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  25. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/interval.py +0 -0
  26. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  27. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  28. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/data/transform.py +0 -0
  35. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/__init__.py +0 -0
  36. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/_experimental.py +0 -0
  37. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/actsave.py +0 -0
  38. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/callbacks.py +0 -0
  39. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/config.py +0 -0
  40. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/data.py +0 -0
  41. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/log.py +0 -0
  42. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  43. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/model.py +0 -0
  44. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/nn.py +0 -0
  45. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/optimizer.py +0 -0
  46. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/runner.py +0 -0
  47. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/snoop.py +0 -0
  48. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/trainer.py +0 -0
  49. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/typecheck.py +0 -0
  50. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/ll/util.py +0 -0
  51. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/loggers/__init__.py +0 -0
  52. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/loggers/_base.py +0 -0
  53. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/loggers/csv.py +0 -0
  54. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
  55. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/loggers/wandb.py +0 -0
  56. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  57. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  58. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  59. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  60. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/metrics/__init__.py +0 -0
  61. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/metrics/_config.py +0 -0
  62. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/__init__.py +0 -0
  63. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/base.py +0 -0
  64. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/config.py +0 -0
  65. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/callback.py +0 -0
  66. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/debug.py +0 -0
  67. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  68. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/logger.py +0 -0
  69. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  70. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  71. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  72. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/nn/__init__.py +0 -0
  73. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/nn/mlp.py +0 -0
  74. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/nn/module_dict.py +0 -0
  75. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/nn/module_list.py +0 -0
  76. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  77. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/optimizer.py +0 -0
  78. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/runner.py +0 -0
  79. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  80. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/trainer/__init__.py +0 -0
  81. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  82. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  83. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  84. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/_environment_info.py +0 -0
  85. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/_useful_types.py +0 -0
  86. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/environment.py +0 -0
  87. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/seed.py +0 -0
  88. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/slurm.py +0 -0
  89. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/typed.py +0 -0
  90. {nshtrainer-0.15.1 → nshtrainer-0.16.1}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@ PKG-INFO
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.15.1
+Version: 0.16.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@ pyproject.toml
 [tool.poetry]
 name = "nshtrainer"
-version = "0.15.1"
+version = "0.16.1"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -10,7 +10,7 @@ from lightning.pytorch.trainer.states import TrainerFn
 from typing_extensions import assert_never
 
 from ..metrics._config import MetricConfig
-from .metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
+from .metadata import CheckpointMetadata
 
 if TYPE_CHECKING:
     from ..model.config import BaseConfig
@@ -263,13 +263,13 @@ def _checkpoint_candidates(
 
     # Load all checkpoints in the directory.
     # We can do this by looking for metadata files.
-    for path in ckpt_dir.glob(f"*{METADATA_PATH_SUFFIX}"):
+    for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
         if (meta := _load_ckpt_meta(path, root_config)) is not None:
             yield meta
 
     # If we have a pre-empted checkpoint, load it
     if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
-        hpc_meta_path = Path(hpc_path).with_suffix(METADATA_PATH_SUFFIX)
+        hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
         if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
             yield meta
@@ -279,7 +279,9 @@ def _additional_candidates(
 ):
     for path in additional_candidates:
         if (
-            meta := _load_ckpt_meta(path.with_suffix(METADATA_PATH_SUFFIX), root_config)
+            meta := _load_ckpt_meta(
+                path.with_suffix(CheckpointMetadata.PATH_SUFFIX), root_config
+            )
         ) is None:
             continue
         yield meta
@@ -310,7 +312,7 @@ def _resolve_checkpoint(
     match strategy:
         case UserProvidedPathCheckpointStrategyConfig():
             meta = _load_ckpt_meta(
-                strategy.path.with_suffix(METADATA_PATH_SUFFIX),
+                strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
                 root_config,
                 on_error=strategy.on_error,
             )
@@ -10,8 +10,6 @@ import nshconfig as C
 import numpy as np
 import torch
 
-from ..util._environment_info import EnvironmentConfig
-
 if TYPE_CHECKING:
     from ..model import BaseConfig, LightningModuleBase
     from ..trainer.trainer import Trainer
@@ -38,7 +36,7 @@ class CheckpointMetadata(C.Config):
     global_step: int
     training_time: datetime.timedelta
     metrics: dict[str, Any]
-    environment: EnvironmentConfig
+    environment: dict[str, Any]
 
     hparams: dict[str, Any] | None
@@ -48,9 +46,7 @@ class CheckpointMetadata(C.Config):
 
     @classmethod
     def from_ckpt_path(cls, checkpoint_path: Path):
-        if not (
-            metadata_path := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
-        ).exists():
+        if not (metadata_path := checkpoint_path.with_suffix(cls.PATH_SUFFIX)).exists():
             raise FileNotFoundError(
                 f"Metadata file not found for checkpoint: {checkpoint_path}"
             )
@@ -93,7 +89,7 @@ def _generate_checkpoint_metadata(
         global_step=trainer.global_step,
         training_time=training_time,
         metrics=metrics,
-        environment=config.environment,
+        environment=config.environment.model_dump(mode="json"),
         hparams=config.model_dump(mode="json"),
     )
@@ -104,7 +100,7 @@ def _write_checkpoint_metadata(
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-    metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    metadata_path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
     metadata = _generate_checkpoint_metadata(
         config, trainer, checkpoint_path, metadata_path
     )
@@ -119,7 +115,7 @@ def _write_checkpoint_metadata(
 
 
 def _remove_checkpoint_metadata(checkpoint_path: Path):
-    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
     try:
         path.unlink(missing_ok=True)
     except Exception as e:
@@ -133,8 +129,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)
 
     # Link the metadata files to the new checkpoint
-    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
-    linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
+    linked_path = linked_checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
     try:
         try:
             # linked_path.symlink_to(path)
@@ -102,6 +102,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         metas = [
             CheckpointMetadata.from_file(p)
             for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+            if p.is_file() and not p.is_symlink()
         ]
 
         # Sort by the topk sort key
@@ -0,0 +1 @@ nshtrainer-0.16.1/src/nshtrainer/ll/snapshot.py (new file)
+from nshsnap import *  # pyright: ignore[reportWildcardImportFromLibrary] # noqa: F403
@@ -419,7 +419,8 @@ class Trainer(LightningTrainer):
 
         # Save the checkpoint metadata
         lm = self._base_module
-        if lm.config.trainer.save_checkpoint_metadata and self.is_global_zero:
+        hparams = cast(BaseConfig, lm.hparams)
+        if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             _write_checkpoint_metadata(self, lm, filepath)
@@ -1 +0,0 @@ nshtrainer-0.15.1/src/nshtrainer/ll/snapshot.py (deleted file)
-from nshrunner.snapshot import *  # type: ignore # noqa: F403
File without changes