nshtrainer 0.11.11__tar.gz → 0.11.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_checkpoint/metadata.py +5 -1
  4. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_checkpoint/saver.py +11 -8
  5. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/checkpoint/_base.py +7 -4
  6. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -0
  7. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -0
  8. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/config.py +2 -30
  9. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/README.md +0 -0
  10. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/__init__.py +0 -0
  11. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_checkpoint/loader.py +0 -0
  12. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_experimental/__init__.py +0 -0
  13. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  14. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  15. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  16. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/__init__.py +0 -0
  17. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  18. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/actsave.py +0 -0
  19. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/base.py +0 -0
  20. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  21. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  22. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  23. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/ema.py +0 -0
  24. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  25. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  26. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/interval.py +0 -0
  27. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  28. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  29. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/print_table.py +0 -0
  30. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  31. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/timer.py +0 -0
  32. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  33. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/data/__init__.py +0 -0
  34. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  35. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/data/transform.py +0 -0
  36. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/__init__.py +0 -0
  37. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/_experimental.py +0 -0
  38. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/actsave.py +0 -0
  39. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/callbacks.py +0 -0
  40. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/config.py +0 -0
  41. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/data.py +0 -0
  42. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/log.py +0 -0
  43. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  44. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/model.py +0 -0
  45. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/nn.py +0 -0
  46. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/optimizer.py +0 -0
  47. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/runner.py +0 -0
  48. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/snapshot.py +0 -0
  49. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/snoop.py +0 -0
  50. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/trainer.py +0 -0
  51. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/typecheck.py +0 -0
  52. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/ll/util.py +0 -0
  53. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  54. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  55. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  56. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  57. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/metrics/__init__.py +0 -0
  58. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/metrics/_config.py +0 -0
  59. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/__init__.py +0 -0
  60. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/base.py +0 -0
  61. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/callback.py +0 -0
  62. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/debug.py +0 -0
  63. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/distributed.py +0 -0
  64. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/logger.py +0 -0
  65. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/profiler.py +0 -0
  66. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  67. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  68. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/nn/__init__.py +0 -0
  69. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/nn/mlp.py +0 -0
  70. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/nn/module_dict.py +0 -0
  71. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/nn/module_list.py +0 -0
  72. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/nn/nonlinearity.py +0 -0
  73. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/optimizer.py +0 -0
  74. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/runner.py +0 -0
  75. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/trainer/trainer.py +0 -0
  81. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/_environment_info.py +0 -0
  82. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/_useful_types.py +0 -0
  83. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/environment.py +0 -0
  84. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/seed.py +0 -0
  85. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/slurm.py +0 -0
  86. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/typed.py +0 -0
  87. {nshtrainer-0.11.11 → nshtrainer-0.11.13}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.11.11
3
+ Version: 0.11.13
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.11.11"
3
+ version = "0.11.13"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -137,7 +137,11 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
137
137
  linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
138
138
  try:
139
139
  try:
140
- linked_path.symlink_to(path)
140
+ # linked_path.symlink_to(path)
141
+ # We should store the path as a relative path
142
+ # to the metadata file to avoid issues with
143
+ # moving the checkpoint directory
144
+ linked_path.symlink_to(path.relative_to(linked_path.parent))
141
145
  except OSError:
142
146
  # on Windows, special permissions are required to create symbolic links as a regular user
143
147
  # fall back to copying the file
@@ -12,22 +12,25 @@ def _link_checkpoint(
12
12
  linkpath: str | Path | os.PathLike,
13
13
  *,
14
14
  metadata: bool,
15
+ remove_existing: bool = True,
15
16
  ):
16
17
  if not isinstance(filepath, Path):
17
18
  filepath = Path(filepath)
18
19
  if not isinstance(linkpath, Path):
19
20
  linkpath = Path(linkpath)
20
21
 
21
- if linkpath.exists():
22
- if linkpath.is_symlink() or linkpath.is_file():
23
- linkpath.unlink()
24
- elif linkpath.is_dir():
25
- shutil.rmtree(linkpath)
26
- _remove_checkpoint_metadata(linkpath)
22
+ if remove_existing:
23
+ if linkpath.exists():
24
+ if linkpath.is_symlink() or linkpath.is_file():
25
+ linkpath.unlink()
26
+ elif linkpath.is_dir():
27
+ shutil.rmtree(linkpath)
28
+
29
+ if metadata:
30
+ _remove_checkpoint_metadata(linkpath)
27
31
 
28
32
  try:
29
- target_path = filepath.relative_to(linkpath.parent)
30
- linkpath.symlink_to(target_path)
33
+ linkpath.symlink_to(filepath.relative_to(linkpath.parent))
31
34
  except OSError:
32
35
  # on Windows, special permissions are required to create symbolic links as a regular user
33
36
  # fall back to copying the file
@@ -79,6 +79,9 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
79
79
  @abstractmethod
80
80
  def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
81
81
 
82
+ @abstractmethod
83
+ def topk_sort_reverse(self) -> bool: ...
84
+
82
85
  def symlink_path(self):
83
86
  if not self.config.save_symlink:
84
87
  return None
@@ -102,7 +105,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
102
105
  ]
103
106
 
104
107
  # Sort by the topk sort key
105
- metas = sorted(metas, key=self.topk_sort_key)
108
+ metas = sorted(metas, key=self.topk_sort_key, reverse=self.topk_sort_reverse())
106
109
 
107
110
  # Now, the metas are sorted from the best to the worst,
108
111
  # so we can remove the worst checkpoints
@@ -145,15 +148,15 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
145
148
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
146
149
 
147
150
  if trainer.is_global_zero:
148
- # Remove old checkpoints
149
- self.remove_old_checkpoints(trainer)
150
-
151
151
  # Create the latest symlink
152
152
  if (symlink_filename := self.symlink_path()) is not None:
153
153
  symlink_path = self.dirpath / symlink_filename
154
154
  _link_checkpoint(filepath, symlink_path, metadata=True)
155
155
  log.debug(f"Created latest symlink: {symlink_path}")
156
156
 
157
+ # Remove old checkpoints
158
+ self.remove_old_checkpoints(trainer)
159
+
157
160
  # Barrier to ensure all processes have saved the checkpoint,
158
161
  # deleted the old checkpoints, and created the symlink before continuing
159
162
  trainer.strategy.barrier()
@@ -64,6 +64,10 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
64
64
  float("-inf" if self.metric.mode == "max" else "inf"),
65
65
  )
66
66
 
67
+ @override
68
+ def topk_sort_reverse(self):
69
+ return self.metric.mode == "max"
70
+
67
71
  # Events
68
72
  @override
69
73
  def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
@@ -34,6 +34,10 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
34
34
  def topk_sort_key(self, metadata: CheckpointMetadata):
35
35
  return metadata.checkpoint_timestamp
36
36
 
37
+ @override
38
+ def topk_sort_reverse(self):
39
+ return True
40
+
37
41
  @override
38
42
  def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
39
43
  self.save_checkpoints(trainer)
@@ -3,7 +3,6 @@ import logging
3
3
  import os
4
4
  import string
5
5
  import time
6
- import warnings
7
6
  from abc import ABC, abstractmethod
8
7
  from collections.abc import Iterable, Sequence
9
8
  from datetime import timedelta
@@ -50,10 +49,6 @@ from ..util._environment_info import EnvironmentConfig
50
49
  log = logging.getLogger(__name__)
51
50
 
52
51
 
53
- class IdSeedWarning(Warning):
54
- pass
55
-
56
-
57
52
  class BaseProfilerConfig(C.Config, ABC):
58
53
  dirpath: str | Path | None = None
59
54
  """
@@ -1478,35 +1473,12 @@ class BaseConfig(C.Config):
1478
1473
  _rng: ClassVar[np.random.Generator | None] = None
1479
1474
 
1480
1475
  @staticmethod
1481
- def generate_id(
1482
- *,
1483
- length: int = 8,
1484
- ignore_rng: bool = False,
1485
- ) -> str:
1476
+ def generate_id(*, length: int = 8) -> str:
1486
1477
  """
1487
1478
  Generate a random ID of specified length.
1488
1479
 
1489
- Args:
1490
- length (int): The length of the generated ID. Default is 8.
1491
- ignore_rng (bool): If True, ignore the global random number generator and use a new one. Default is False.
1492
-
1493
- Returns:
1494
- str: The generated random ID.
1495
-
1496
- Raises:
1497
- IdSeedWarning: If the global random number generator is None and ignore_rng is False.
1498
-
1499
- Notes:
1500
- - The generated IDs will not be reproducible if the global random number generator is None and ignore_rng is False.
1501
- - To ensure reproducibility, call BaseConfig.set_seed(...) before generating any IDs.
1502
1480
  """
1503
- rng = BaseConfig._rng if not ignore_rng else np.random.default_rng()
1504
- if rng is None:
1505
- warnings.warn(
1506
- "BaseConfig._rng is None. The generated IDs will not be reproducible. "
1507
- + "To fix this, call BaseConfig.set_seed(...) before generating any IDs.",
1508
- category=IdSeedWarning,
1509
- )
1481
+ if (rng := BaseConfig._rng) is None:
1510
1482
  rng = np.random.default_rng()
1511
1483
 
1512
1484
  alphabet = list(string.ascii_lowercase + string.digits)
File without changes