nshtrainer 0.21.0__tar.gz → 0.22.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/PKG-INFO +1 -2
  2. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/pyproject.toml +1 -3
  3. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_checkpoint/metadata.py +2 -1
  4. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_hf_hub.py +2 -9
  5. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/base.py +22 -19
  6. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/trainer.py +2 -1
  7. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/README.md +0 -0
  8. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/__init__.py +0 -0
  9. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_callback.py +0 -0
  10. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  11. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  12. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  13. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  14. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  15. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  16. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  17. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  18. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  19. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  20. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  21. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  22. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/ema.py +0 -0
  23. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  24. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  25. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/interval.py +0 -0
  26. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  27. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  28. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/data/transform.py +0 -0
  35. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/__init__.py +0 -0
  36. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/_experimental.py +0 -0
  37. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/actsave.py +0 -0
  38. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/callbacks.py +0 -0
  39. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/config.py +0 -0
  40. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/data.py +0 -0
  41. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/log.py +0 -0
  42. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  43. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/model.py +0 -0
  44. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/nn.py +0 -0
  45. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/optimizer.py +0 -0
  46. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/runner.py +0 -0
  47. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/snapshot.py +0 -0
  48. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/snoop.py +0 -0
  49. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/trainer.py +0 -0
  50. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/typecheck.py +0 -0
  51. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/ll/util.py +0 -0
  52. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/__init__.py +0 -0
  53. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/_base.py +0 -0
  54. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/csv.py +0 -0
  55. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
  56. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/loggers/wandb.py +0 -0
  57. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  58. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  59. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  60. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  61. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/metrics/__init__.py +0 -0
  62. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/metrics/_config.py +0 -0
  63. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/__init__.py +0 -0
  64. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/base.py +0 -0
  65. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/config.py +0 -0
  66. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/callback.py +0 -0
  67. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/debug.py +0 -0
  68. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  69. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/logger.py +0 -0
  70. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  71. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  72. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  73. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/__init__.py +0 -0
  74. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/mlp.py +0 -0
  75. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/module_dict.py +0 -0
  76. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/module_list.py +0 -0
  77. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  78. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/optimizer.py +0 -0
  79. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/runner.py +0 -0
  80. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  81. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/__init__.py +0 -0
  82. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  83. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  84. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  85. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/_environment_info.py +0 -0
  86. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/_useful_types.py +0 -0
  87. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/environment.py +0 -0
  88. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/path.py +0 -0
  89. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/seed.py +0 -0
  90. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/slurm.py +0 -0
  91. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/typed.py +0 -0
  92. {nshtrainer-0.21.0 → nshtrainer-0.22.1}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.21.0
3
+ Version: 0.22.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -26,7 +26,6 @@ Requires-Dist: torchmetrics ; extra == "extra"
26
26
  Requires-Dist: typing-extensions
27
27
  Requires-Dist: wandb ; extra == "extra"
28
28
  Requires-Dist: wrapt ; extra == "extra"
29
- Requires-Dist: zstandard ; extra == "extra"
30
29
  Description-Content-Type: text/markdown
31
30
 
32
31
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.21.0"
3
+ version = "0.22.1"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -23,7 +23,6 @@ GitPython = { version = "*", optional = true }
23
23
  wandb = { version = "*", optional = true }
24
24
  tensorboard = { version = "*", optional = true }
25
25
  huggingface-hub = { version = "*", optional = true }
26
- zstandard = { version = "*", optional = true }
27
26
 
28
27
  [tool.poetry.group.dev.dependencies]
29
28
  pyright = "^1.1.372"
@@ -54,5 +53,4 @@ extra = [
54
53
  "wandb",
55
54
  "tensorboard",
56
55
  "huggingface-hub",
57
- "zstandard",
58
56
  ]
@@ -11,6 +11,7 @@ import numpy as np
11
11
  import torch
12
12
 
13
13
  from ..util._environment_info import EnvironmentConfig
14
+ from ..util.path import get_relative_path
14
15
 
15
16
  if TYPE_CHECKING:
16
17
  from ..model import BaseConfig, LightningModuleBase
@@ -145,7 +146,7 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
145
146
  # We should store the path as a relative path
146
147
  # to the metadata file to avoid issues with
147
148
  # moving the checkpoint directory
148
- linked_path.symlink_to(path.relative_to(linked_path.parent))
149
+ linked_path.symlink_to(get_relative_path(linked_path, path))
149
150
  except OSError:
150
151
  # on Windows, special permissions are required to create symbolic links as a regular user
151
152
  # fall back to copying the file
@@ -10,11 +10,7 @@ from nshrunner._env import SNAPSHOT_DIR
10
10
  from typing_extensions import override
11
11
 
12
12
  from ._callback import NTCallbackBase
13
- from .callbacks.base import (
14
- CallbackConfigBase,
15
- CallbackMetadataConfig,
16
- CallbackWithMetadata,
17
- )
13
+ from .callbacks.base import CallbackConfigBase
18
14
 
19
15
  if TYPE_CHECKING:
20
16
  from huggingface_hub import HfApi # noqa: F401
@@ -81,10 +77,7 @@ class HuggingFaceHubConfig(CallbackConfigBase):
81
77
 
82
78
  @override
83
79
  def create_callbacks(self, root_config):
84
- yield CallbackWithMetadata(
85
- HFHubCallback(self),
86
- CallbackMetadataConfig(ignore_if_exists=True),
87
- )
80
+ yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
88
81
 
89
82
 
90
83
  def _api(token: str | None = None):
@@ -2,29 +2,24 @@ from abc import ABC, abstractmethod
2
2
  from collections import Counter
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass
5
- from typing import TYPE_CHECKING, TypeAlias, TypedDict
5
+ from typing import TYPE_CHECKING, ClassVar, TypeAlias
6
6
 
7
7
  import nshconfig as C
8
8
  from lightning.pytorch import Callback
9
+ from typing_extensions import TypedDict, Unpack
9
10
 
10
11
  if TYPE_CHECKING:
11
12
  from ..model.config import BaseConfig
12
13
 
13
14
 
14
- class CallbackMetadataDict(TypedDict, total=False):
15
+ class CallbackMetadataConfig(TypedDict, total=False):
15
16
  ignore_if_exists: bool
16
- """If `True`, the callback will not be added if another callback with the same class already exists."""
17
+ """If `True`, the callback will not be added if another callback with the same class already exists.
18
+ Default is `False`."""
17
19
 
18
20
  priority: int
19
- """Priority of the callback. Callbacks with higher priority will be loaded first."""
20
-
21
-
22
- class CallbackMetadataConfig(C.Config):
23
- ignore_if_exists: bool = False
24
- """If `True`, the callback will not be added if another callback with the same class already exists."""
25
-
26
- priority: int = 0
27
- """Priority of the callback. Callbacks with higher priority will be loaded first."""
21
+ """Priority of the callback. Callbacks with higher priority will be loaded first.
22
+ Default is `0`."""
28
23
 
29
24
 
30
25
  @dataclass(frozen=True)
@@ -37,13 +32,18 @@ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
37
32
 
38
33
 
39
34
  class CallbackConfigBase(C.Config, ABC):
40
- metadata: CallbackMetadataConfig = CallbackMetadataConfig()
35
+ metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig()
41
36
  """Metadata for the callback."""
42
37
 
43
- def with_metadata(self, callback: Callback, **metadata: CallbackMetadataDict):
44
- return CallbackWithMetadata(
45
- callback=callback, metadata=self.metadata.model_copy(update=metadata)
46
- )
38
+ @classmethod
39
+ def with_metadata(
40
+ cls, callback: Callback, **kwargs: Unpack[CallbackMetadataConfig]
41
+ ):
42
+ metadata: CallbackMetadataConfig = {}
43
+ metadata.update(cls.metadata)
44
+ metadata.update(kwargs)
45
+
46
+ return CallbackWithMetadata(callback=callback, metadata=metadata)
47
47
 
48
48
  @abstractmethod
49
49
  def create_callbacks(
@@ -73,7 +73,7 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
73
73
  for callback in callbacks:
74
74
  # If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
75
75
  if (
76
- callback.metadata.ignore_if_exists
76
+ callback.metadata.get("ignore_if_exists", False)
77
77
  and callback_classes[callback.callback.__class__] > 1
78
78
  ):
79
79
  continue
@@ -89,7 +89,10 @@ def _process_and_filter_callbacks(
89
89
  callbacks = list(callbacks)
90
90
 
91
91
  # Sort by priority (higher priority first)
92
- callbacks.sort(key=lambda callback: callback.metadata.priority, reverse=True)
92
+ callbacks.sort(
93
+ key=lambda callback: callback.metadata.get("priority", 0),
94
+ reverse=True,
95
+ )
93
96
 
94
97
  # Process `ignore_if_exists`
95
98
  callbacks = _filter_ignore_if_exists(callbacks)
@@ -439,7 +439,8 @@ class Trainer(LightningTrainer):
439
439
  ):
440
440
  # If we have a cached path, then we symlink it to the new path.
441
441
  log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
442
- _link_checkpoint(cached_path, filepath, metadata=False)
442
+ if self.is_global_zero:
443
+ _link_checkpoint(cached_path, filepath, metadata=False)
443
444
  else:
444
445
  super().save_checkpoint(filepath, weights_only, storage_options)
445
446
 
File without changes