nshtrainer 0.27.0.tar.gz → 0.28.0.tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (92)
  1. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_hf_hub.py +0 -16
  4. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/checkpoint/_base.py +1 -3
  5. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/config.py +0 -2
  6. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/trainer/trainer.py +5 -57
  7. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/README.md +0 -0
  8. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/__init__.py +0 -0
  9. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_callback.py +0 -0
  10. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  11. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  12. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  13. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  14. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  15. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  17. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/base.py +0 -0
  18. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  19. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  20. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  21. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  22. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  23. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/ema.py +0 -0
  24. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  25. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  26. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/interval.py +0 -0
  27. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  28. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  29. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  30. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  31. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/timer.py +0 -0
  32. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  33. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/data/__init__.py +0 -0
  34. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  35. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/data/transform.py +0 -0
  36. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/__init__.py +0 -0
  37. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/_experimental.py +0 -0
  38. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/actsave.py +0 -0
  39. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/callbacks.py +0 -0
  40. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/config.py +0 -0
  41. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/data.py +0 -0
  42. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/log.py +0 -0
  43. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  44. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/model.py +0 -0
  45. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/nn.py +0 -0
  46. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/optimizer.py +0 -0
  47. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/runner.py +0 -0
  48. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/snapshot.py +0 -0
  49. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/snoop.py +0 -0
  50. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/trainer.py +0 -0
  51. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/typecheck.py +0 -0
  52. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/ll/util.py +0 -0
  53. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/loggers/__init__.py +0 -0
  54. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/loggers/_base.py +0 -0
  55. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/loggers/csv.py +0 -0
  56. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  57. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/loggers/wandb.py +0 -0
  58. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  59. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  60. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  61. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  62. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/metrics/__init__.py +0 -0
  63. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/metrics/_config.py +0 -0
  64. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/__init__.py +0 -0
  65. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/base.py +0 -0
  66. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/callback.py +0 -0
  67. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/debug.py +0 -0
  68. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/distributed.py +0 -0
  69. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/logger.py +0 -0
  70. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/profiler.py +0 -0
  71. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  72. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  73. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/nn/__init__.py +0 -0
  74. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/nn/mlp.py +0 -0
  75. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/nn/module_dict.py +0 -0
  76. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/nn/module_list.py +0 -0
  77. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  78. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/optimizer.py +0 -0
  79. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/runner.py +0 -0
  80. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  81. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/trainer/__init__.py +0 -0
  82. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  83. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  84. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  85. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/_environment_info.py +0 -0
  86. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/_useful_types.py +0 -0
  87. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/environment.py +0 -0
  88. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/path.py +0 -0
  89. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/seed.py +0 -0
  90. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/slurm.py +0 -0
  91. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/typed.py +0 -0
  92. {nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.27.0 → nshtrainer-0.28.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.27.0
+Version: 0.28.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.27.0 → nshtrainer-0.28.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.27.0"
+version = "0.28.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/_hf_hub.py
@@ -359,19 +359,3 @@ class HFHubCallback(NTCallbackBase):
         # NOTE: This file is fairly small, so we can just upload it directly.
         # No need to copy.
         self._save_file(metadata_path)
-
-    @override
-    def state_dict(self):
-        return {
-            "repo_id": self._repo_id,
-            "checksum_to_path_in_repo": {
-                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
-            },
-        }
-
-    @override
-    def load_state_dict(self, state_dict):
-        self._repo_id = state_dict["repo_id"]
-        self._checksum_to_path_in_repo = {
-            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
-        }
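Context for this removal: `state_dict` and `load_state_dict` are Lightning's standard hooks for persisting callback state inside the trainer checkpoint, so dropping them means `HFHubCallback` no longer carries its repo ID or its checksum-to-repo-path map across a resume. A minimal, self-contained sketch of that round-trip; the callback name below is a hypothetical stand-in, not nshtrainer code:

from pathlib import Path
from typing import Any

from lightning.pytorch.callbacks import Callback


class UploadTrackingCallback(Callback):
    """Hypothetical stand-in mirroring the removed HFHubCallback hooks."""

    def __init__(self) -> None:
        self._repo_id: str | None = None
        self._checksum_to_path_in_repo: dict[str, Path] = {}

    def state_dict(self) -> dict[str, Any]:
        # The returned mapping is embedded in the trainer checkpoint;
        # paths are stringified so the payload stays plain and picklable.
        return {
            "repo_id": self._repo_id,
            "checksum_to_path_in_repo": {
                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
            },
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Called on resume with exactly what state_dict() returned.
        self._repo_id = state_dict["repo_id"]
        self._checksum_to_path_in_repo = {
            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
        }

With the hooks gone, the callback presumably starts from an empty checksum map on every run, so files already uploaded before a restart may be re-uploaded rather than deduplicated.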
{nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/callbacks/checkpoint/_base.py
@@ -152,9 +152,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):

         # Save the new checkpoint
         filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-        trainer._nshtrainer_save_checkpoint(
-            filepath, self.config.save_weights_only, use_checkpoint_cache=None
-        )
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)

         if trainer.is_global_zero:
             # Create the latest symlink
{nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/model/config.py
@@ -1012,8 +1012,6 @@ class TrainerConfig(C.Config):
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
-    use_checkpoint_cache: bool = False
-    """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""

     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
{nshtrainer-0.27.0 → nshtrainer-0.28.0}/src/nshtrainer/trainer/trainer.py
@@ -1,7 +1,5 @@
 import logging
 import os
-import shutil
-from collections import defaultdict
 from collections.abc import Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
@@ -280,12 +278,6 @@ class Trainer(LightningTrainer):
     if TYPE_CHECKING:
         callbacks: list[Callback]

-    def _nshtrainer_checkpoint_cache_get(self, key: tuple[int, int]):
-        return next(
-            (ckpt for ckpt in self._nshtrainer_checkpoint_cache[key] if ckpt.exists()),
-            None,
-        )
-
     @override
     def __init__(
         self,
@@ -293,10 +285,6 @@ class Trainer(LightningTrainer):
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
-        self._nshtrainer_checkpoint_cache = defaultdict[tuple[int, int], list[Path]](
-            lambda: []
-        )
-
        self._pre_init(config)

        kwargs = self._update_kwargs(config, kwargs)
@@ -419,50 +407,24 @@ class Trainer(LightningTrainer):

         return super()._run(model, ckpt_path)

-    def _nshtrainer_save_checkpoint(
+    @override
+    def save_checkpoint(
         self,
         filepath: str | Path,
         weights_only: bool = False,
         storage_options: Any | None = None,
-        use_checkpoint_cache: bool | None = None,
     ):
-        lm = self._base_module
-        root_config = cast(BaseConfig, lm.hparams)
-        if use_checkpoint_cache is None:
-            use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
-
         filepath = Path(filepath)

         # List of files that we should upload to HF
         written_files: list[Path] = [filepath]

-        cached_path = None
-        if (
-            use_checkpoint_cache
-            and (
-                cached_path := self._nshtrainer_checkpoint_cache_get(
-                    (self.current_epoch, self.global_step)
-                )
-            )
-            is not None
-        ):
-            # If we have a cached path, then we symlink it to the new path.
-            log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
-            if self.is_global_zero:
-                shutil.copy(cached_path, filepath)
-            self.strategy.barrier("Trainer.save_checkpoint")
-        else:
-            super().save_checkpoint(filepath, weights_only, storage_options)
-
-        # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
-        if use_checkpoint_cache and cached_path is None:
-            self._nshtrainer_checkpoint_cache[
-                (self.current_epoch, self.global_step)
-            ].append(filepath)
-            log.debug(f"Checkpoint saved to cache: {filepath}")
+        super().save_checkpoint(filepath, weights_only, storage_options)

         # Save the checkpoint metadata
         metadata_path = None
+        lm = self._base_module
+        root_config = cast(BaseConfig, lm.hparams)
         if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
@@ -474,17 +436,3 @@ class Trainer(LightningTrainer):
         from .. import _callback

         _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
-
-    @override
-    def save_checkpoint(
-        self,
-        filepath: str | Path,
-        weights_only: bool = False,
-        storage_options: Any | None = None,
-    ):
-        return self._nshtrainer_save_checkpoint(
-            filepath=filepath,
-            weights_only=weights_only,
-            storage_options=storage_options,
-            use_checkpoint_cache=False,
-        )
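Net effect of the trainer changes: `save_checkpoint` is once again a thin `@override` that defers to Lightning and then writes checkpoint metadata and fires `_call_on_checkpoint_saved`, while the `use_checkpoint_cache` flag and the `(epoch, global_step)`-keyed cache are removed entirely. (Note that despite the removed config docstring mentioning symlinks, the removed code actually copied the cached file with `shutil.copy`.) For reference, here is the removed deduplication idea as a standalone sketch; every name in it is hypothetical rather than nshtrainer API:

import shutil
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path


class CheckpointCache:
    """Hypothetical standalone version of the removed (epoch, step) cache."""

    def __init__(self) -> None:
        self._cache: defaultdict[tuple[int, int], list[Path]] = defaultdict(list)

    def save(
        self,
        key: tuple[int, int],
        filepath: Path,
        serialize: Callable[[Path], None],
    ) -> None:
        # Reuse the first file already written for this (epoch, step) key
        # that still exists on disk; otherwise serialize and remember it.
        cached = next((p for p in self._cache[key] if p.exists()), None)
        if cached is not None:
            shutil.copy(cached, filepath)  # cheap file copy, no re-serialization
        else:
            serialize(filepath)  # e.g. Lightning's real save_checkpoint
            self._cache[key].append(filepath)


# Two callbacks saving at the same (epoch, step) cost one serialization
# plus one file copy:
cache = CheckpointCache()
cache.save((3, 100), Path("best.ckpt"), lambda p: p.write_bytes(b"state"))
cache.save((3, 100), Path("last.ckpt"), lambda p: p.write_bytes(b"state"))

The upside of the removed cache was avoiding redundant serialization when several checkpoint callbacks fire at the same step; the cost was extra bookkeeping on the trainer, which this release drops in favor of the plain Lightning path.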