nshtrainer 0.24.0.tar.gz → 0.26.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/PKG-INFO +2 -2
  2. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/pyproject.toml +3 -10
  3. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/metadata.py +3 -1
  4. nshtrainer-0.26.0/src/nshtrainer/_hf_hub.py +353 -0
  5. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
  6. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  7. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/gradient_skipping.py +1 -8
  8. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/trainer.py +0 -11
  9. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/path.py +18 -0
  10. nshtrainer-0.24.0/src/nshtrainer/_hf_hub.py +0 -518
  11. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/README.md +0 -0
  12. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/__init__.py +0 -0
  13. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_callback.py +0 -0
  14. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  15. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  16. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  17. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  18. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  19. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  20. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/base.py +0 -0
  21. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  22. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  23. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  24. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  25. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/ema.py +0 -0
  26. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  27. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/interval.py +0 -0
  28. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  29. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  30. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  31. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  32. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/timer.py +0 -0
  33. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  34. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/data/__init__.py +0 -0
  35. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  36. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/data/transform.py +0 -0
  37. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/__init__.py +0 -0
  38. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/_experimental.py +0 -0
  39. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/actsave.py +0 -0
  40. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/callbacks.py +0 -0
  41. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/config.py +0 -0
  42. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/data.py +0 -0
  43. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/log.py +0 -0
  44. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  45. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/model.py +0 -0
  46. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/nn.py +0 -0
  47. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/optimizer.py +0 -0
  48. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/runner.py +0 -0
  49. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/snapshot.py +0 -0
  50. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/snoop.py +0 -0
  51. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/trainer.py +0 -0
  52. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/typecheck.py +0 -0
  53. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/ll/util.py +0 -0
  54. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/__init__.py +0 -0
  55. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/_base.py +0 -0
  56. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/csv.py +0 -0
  57. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  58. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/loggers/wandb.py +0 -0
  59. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  60. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  61. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  62. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  63. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/metrics/__init__.py +0 -0
  64. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/metrics/_config.py +0 -0
  65. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/__init__.py +0 -0
  66. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/base.py +0 -0
  67. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/config.py +0 -0
  68. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/callback.py +0 -0
  69. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/debug.py +0 -0
  70. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/distributed.py +0 -0
  71. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/logger.py +0 -0
  72. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/profiler.py +0 -0
  73. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  74. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  75. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/__init__.py +0 -0
  76. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/mlp.py +0 -0
  77. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/module_dict.py +0 -0
  78. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/module_list.py +0 -0
  79. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  80. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/optimizer.py +0 -0
  81. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/runner.py +0 -0
  82. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  83. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/__init__.py +0 -0
  84. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  85. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  86. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  87. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/_environment_info.py +0 -0
  88. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/_useful_types.py +0 -0
  89. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/environment.py +0 -0
  90. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/seed.py +0 -0
  91. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/slurm.py +0 -0
  92. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/typed.py +0 -0
  93. {nshtrainer-0.24.0 → nshtrainer-0.26.0}/src/nshtrainer/util/typing_utils.py +0 -0
--- nshtrainer-0.24.0/PKG-INFO
+++ nshtrainer-0.26.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.24.0
+Version: 0.26.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -22,7 +22,7 @@ Requires-Dist: psutil
 Requires-Dist: pytorch-lightning
 Requires-Dist: tensorboard ; extra == "extra"
 Requires-Dist: torch
-Requires-Dist: torchmetrics ; extra == "extra"
+Requires-Dist: torchmetrics
 Requires-Dist: typing-extensions
 Requires-Dist: wandb ; extra == "extra"
 Requires-Dist: wrapt ; extra == "extra"
--- nshtrainer-0.24.0/pyproject.toml
+++ nshtrainer-0.26.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.24.0"
+version = "0.26.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -17,7 +17,7 @@ typing-extensions = "*"
 packaging = "*"
 lightning = "*"
 pytorch-lightning = "*"
-torchmetrics = { version = "*", optional = true }
+torchmetrics = "*"
 wrapt = { version = "*", optional = true }
 GitPython = { version = "*", optional = true }
 wandb = { version = "*", optional = true }
@@ -46,11 +46,4 @@ reportPrivateImportUsage = false
 ignore = ["F722", "F821", "E731", "E741"]
 
 [tool.poetry.extras]
-extra = [
-    "torchmetrics",
-    "wrapt",
-    "GitPython",
-    "wandb",
-    "tensorboard",
-    "huggingface-hub",
-]
+extra = ["wrapt", "GitPython", "wandb", "tensorboard", "huggingface-hub"]
--- nshtrainer-0.24.0/src/nshtrainer/_checkpoint/metadata.py
+++ nshtrainer-0.26.0/src/nshtrainer/_checkpoint/metadata.py
@@ -11,7 +11,7 @@ import numpy as np
 import torch
 
 from ..util._environment_info import EnvironmentConfig
-from ..util.path import get_relative_path
+from ..util.path import compute_file_checksum, get_relative_path
 
 if TYPE_CHECKING:
     from ..model import BaseConfig, LightningModuleBase
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
 
     checkpoint_path: Path
    checkpoint_filename: str
+    checkpoint_checksum: str
 
     run_id: str
     name: str
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
         # moving the checkpoint directory
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
+        checkpoint_checksum=compute_file_checksum(checkpoint_path),
         run_id=config.id,
         name=config.run_name,
         project=config.project,
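
Since CheckpointMetadata now records a SHA-256 digest next to the checkpoint path, a checkpoint can be verified after it has been copied or moved. A minimal standalone sketch of such a check (the helper name verify_checkpoint and the plain-json loading are illustrative assumptions; only the field names come from the class above, and checkpoint_path is stored relative to the metadata file per the diff):

import hashlib
import json
from pathlib import Path

def verify_checkpoint(metadata_path: Path) -> bool:
    """Recompute the checkpoint's SHA-256 and compare it to the recorded value."""
    meta = json.loads(metadata_path.read_text())
    # checkpoint_path is stored relative to the metadata file's directory
    ckpt = metadata_path.parent / meta["checkpoint_path"]
    sha256 = hashlib.sha256()
    with ckpt.open("rb") as f:
        for block in iter(lambda: f.read(4096), b""):
            sha256.update(block)
    return sha256.hexdigest() == meta["checkpoint_checksum"]
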
--- /dev/null
+++ nshtrainer-0.26.0/src/nshtrainer/_hf_hub.py
@@ -0,0 +1,353 @@
+import contextlib
+import logging
+import os
+import re
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import nshconfig as C
+from nshrunner._env import SNAPSHOT_DIR
+from typing_extensions import override
+
+from ._callback import NTCallbackBase
+from .callbacks.base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from huggingface_hub import HfApi  # noqa: F401
+
+    from .model.base import BaseConfig
+
+
+log = logging.getLogger(__name__)
+
+
+class HuggingFaceHubAutoCreateConfig(C.Config):
+    enabled: bool = True
+    """Enable automatic repository creation on the Hugging Face Hub."""
+
+    private: bool = True
+    """Whether to create the repository as private."""
+
+    namespace: str | None = None
+    """The namespace to create the repository in. If `None`, the repository will be created in the user's namespace."""
+
+    def __bool__(self):
+        return self.enabled
+
+
+class HuggingFaceHubConfig(CallbackConfigBase):
+    """Configuration options for Hugging Face Hub integration."""
+
+    enabled: bool = False
+    """Enable Hugging Face Hub integration."""
+
+    token: str | None = None
+    """Hugging Face Hub API token. If `None`, the token will be read from the current environment.
+    This needs to either be set using `huggingface-cli login` or by setting the `HUGGING_FACE_HUB_TOKEN`
+    environment variable."""
+
+    auto_create: HuggingFaceHubAutoCreateConfig = HuggingFaceHubAutoCreateConfig()
+    """Automatic repository creation configuration options."""
+
+    save_config: bool = True
+    """Whether to save the model configuration to the Hugging Face Hub."""
+
+    save_checkpoints: bool = True
+    """Whether to save checkpoints to the Hugging Face Hub."""
+
+    save_code: bool = True
+    """Whether to save code to the Hugging Face Hub.
+    This is only supported if `nshsnap` is installed and snapshotting is enabled."""
+
+    save_in_background: bool = True
+    """Whether to save to the Hugging Face Hub in the background.
+    This corresponds to setting `run_as_future=True` in the `HfApi` upload methods."""
+
+    def enable_(self):
+        self.enabled = True
+        return self
+
+    def disable_(self):
+        self.enabled = False
+        return self
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield self.with_metadata(HFHubCallback(self), ignore_if_exists=True)
+
+
+def _api(token: str | None = None):
+    # Make sure that `huggingface_hub` is installed
+    try:
+        import huggingface_hub  # noqa: F401
+    except ImportError:
+        log.exception(
+            "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
+        )
+        return None
+
+    # Create and authenticate the API instance
+    try:
+        api = huggingface_hub.HfApi(token=token)
+
+        # Verify authentication
+        api.whoami()
+    except Exception:
+        log.exception(
+            "Authentication failed for Hugging Face Hub. "
+            "Please make sure you are logged in using `huggingface-cli login`, "
+            "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
+            "or by providing a valid token in the configuration."
+        )
+        return None
+
+    return api
+
+
+def _repo_name(api: "HfApi", root_config: "BaseConfig"):
+    username = None
+    if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
+        username = ac.namespace
+    elif (username := api.whoami().get("name", None)) is None:
+        raise ValueError("Could not get username from Hugging Face Hub.")
+
+    # Sanitize the project (if it exists), run_name, and id
+    parts = []
+    if root_config.project:
+        parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.project))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.run_name))
+    parts.append(re.sub(r"[^a-zA-Z0-9-]", "-", root_config.id))
+
+    # Combine parts and ensure it starts and ends with alphanumeric characters
+    repo_name = "-".join(parts)
+    repo_name = repo_name.strip("-")
+    repo_name = re.sub(
+        r"-+", "-", repo_name
+    )  # Replace multiple dashes with a single dash
+
+    # Ensure the name is not longer than 96 characters (excluding username)
+    if len(repo_name) > 96:
+        repo_name = repo_name[:96].rstrip("-")
+
+    # Ensure the repo name starts with an alphanumeric character
+    repo_name = re.sub(r"^[^a-zA-Z0-9]+", "", repo_name)
+
+    # If the repo_name is empty after all sanitization, use a default name
+    if not repo_name:
+        repo_name = "default-repo-name"
+
+    return f"{username}/{repo_name}"
+
+
+@dataclass
+class _Upload:
+    local_path: Path
+    path_in_repo: Path
+
+    @classmethod
+    def from_local_path(
+        cls,
+        local_path: Path,
+        root_config: "BaseConfig",
+    ):
+        # Resolve the checkpoint directory
+        checkpoint_dir = root_config.directory.resolve_subdirectory(
+            root_config.id, "checkpoint"
+        )
+
+        try:
+            relative_path = local_path.relative_to(checkpoint_dir)
+        except ValueError:
+            raise ValueError(
+                f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
+            )
+
+        # Prefix the path in repo with "checkpoints"
+        path_in_repo = Path("checkpoints") / relative_path
+
+        return cls(local_path=local_path, path_in_repo=path_in_repo)
+
+
+class HFHubCallback(NTCallbackBase):
+    @contextlib.contextmanager
+    def _with_error_handling(self, operation: str):
+        try:
+            yield
+        except Exception:
+            log.exception(f"Failed to {operation}, repo_id={self._repo_id}")
+        else:
+            log.debug(f"Successfully {operation}, repo_id={self._repo_id}")
+
+    def __init__(self, config: HuggingFaceHubConfig):
+        super().__init__()
+
+        self.config = config
+
+        self._repo_id = None
+        self._checksum_to_path_in_repo: dict[str, Path] = {}
+
+    @override
+    def setup(self, trainer, pl_module, stage):
+        from .trainer.trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            raise ValueError(
+                f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
+            )
+
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # Create the repository, if it doesn't exist
+        self._repo_id = self.api.create_repo(
+            repo_id=_repo_name(self.api, root_config),
+            repo_type="model",
+            private=self.config.auto_create.private,
+            exist_ok=True,
+        )
+
+        # Upload the config and code
+        self._save_config(root_config)
+        self._save_code()
+
+    @override
+    def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
+        root_config = cast("BaseConfig", pl_module.hparams)
+
+        # If HF Hub is enabled, then we upload
+        if self.config and trainer.is_global_zero:
+            with self._with_error_handling("save checkpoints"):
+                self._save_checkpoint(
+                    _Upload.from_local_path(ckpt_path, root_config),
+                    _Upload.from_local_path(metadata_path, root_config)
+                    if metadata_path is not None
+                    else None,
+                )
+
+    @cached_property
+    def api(self):
+        # Create and authenticate the API instance
+        if (api := _api(self.config.token)) is None:
+            raise ValueError("Failed to create Hugging Face Hub API instance.")
+        return api
+
+    @property
+    def repo_id(self):
+        if self._repo_id is None:
+            raise ValueError("Repository id has not been initialized.")
+        return self._repo_id
+
+    def _save_config(self, root_config: "BaseConfig"):
+        with self._with_error_handling("upload config"):
+            self.api.upload_file(
+                path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
+                path_in_repo="config.json",
+                repo_id=self.repo_id,
+                repo_type="model",
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_code(self):
+        # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
+        # then upload all contents within the snapshot directory to the repository.
+        if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
+            log.debug("No snapshot directory found. Skipping upload.")
+            return
+
+        with self._with_error_handling("save code"):
+            snapshot_dir = Path(snapshot_dir)
+            if not snapshot_dir.exists() or not snapshot_dir.is_dir():
+                log.warning(
+                    f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
+                )
+                return
+
+            self.api.upload_folder(
+                folder_path=str(snapshot_dir),
+                repo_id=self.repo_id,
+                repo_type="model",
+                path_in_repo="code",  # Prefix with "code" folder
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_file(self, p: _Upload):
+        with self._with_error_handling("save file"):
+            # Upload the checkpoint files to the repository
+            self.api.upload_file(
+                path_or_fileobj=p.local_path,
+                path_in_repo=str(p.path_in_repo),
+                repo_id=self.repo_id,
+                repo_type="model",
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
+        # Create a commit for copying the files
+        from huggingface_hub.hf_api import CommitOperationCopy
+
+        with self._with_error_handling("copy file"):
+            copy_op = CommitOperationCopy(
+                src_path_in_repo=str(source_path_in_repo),
+                path_in_repo=str(dest_path_in_repo),
+            )
+
+            self.api.create_commit(
+                repo_id=self.repo_id,
+                repo_type="model",
+                commit_message="Copy checkpoint file",
+                operations=[copy_op],
+                run_as_future=cast(Any, self.config.save_in_background),
+            )
+
+    def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
+        if not self.config.save_checkpoints:
+            return
+
+        # If no metadata, just save regularly.
+        if metadata_path is None:
+            self._save_file(path)
+            return
+
+        # Otherwise, check whether a checkpoint with this checksum has already
+        # been uploaded. If so, we can just copy the file within the repo.
+        from ._checkpoint.metadata import CheckpointMetadata
+
+        metadata = CheckpointMetadata.from_file(metadata_path.local_path)
+        if (
+            existing_ckpt_path := self._checksum_to_path_in_repo.get(
+                metadata.checkpoint_checksum
+            )
+        ) is not None:
+            self._copy_file(existing_ckpt_path, path.path_in_repo)
+        else:
+            # Otherwise, we save the checkpoint & keep the checksum so we don't
+            # re-upload the same file again.
+            self._save_file(path)
+            self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
+                path.path_in_repo
+            )
+
+        # Save the metadata file
+        # NOTE: This file is fairly small, so we can just upload it directly.
+        # No need to copy.
+        self._save_file(metadata_path)
+
+    @override
+    def state_dict(self):
+        return {
+            "repo_id": self._repo_id,
+            "checksum_to_path_in_repo": {
+                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
+            },
+        }
+
+    @override
+    def load_state_dict(self, state_dict):
+        self._repo_id = state_dict["repo_id"]
+        self._checksum_to_path_in_repo = {
+            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
+        }
--- nshtrainer-0.24.0/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
+++ nshtrainer-0.26.0/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
 
     # Events
     @override
-    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self.save_checkpoints(trainer)
--- nshtrainer-0.24.0/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
+++ nshtrainer-0.26.0/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
         return True
 
     @override
-    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
         self.save_checkpoints(trainer)
--- nshtrainer-0.24.0/src/nshtrainer/callbacks/gradient_skipping.py
+++ nshtrainer-0.26.0/src/nshtrainer/callbacks/gradient_skipping.py
@@ -1,8 +1,8 @@
-import importlib.util
 import logging
 from typing import Any, Literal, Protocol, runtime_checkable
 
 import torch
+import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
 from typing_extensions import override
@@ -20,19 +20,12 @@ class HasGradSkippedSteps(Protocol):
 
 class GradientSkipping(Callback):
     def __init__(self, config: "GradientSkippingConfig"):
-        if importlib.util.find_spec("torchmetrics") is not None:
-            raise ImportError(
-                "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
-            )
-
         super().__init__()
         self.config = config
 
     @override
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if not isinstance(pl_module, HasGradSkippedSteps):
-            import torchmetrics  # type: ignore
-
             pl_module.grad_skipped_steps = torchmetrics.SumMetric()
 
     @override
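
With torchmetrics now an unconditional dependency, the callback can use torchmetrics.SumMetric directly at module level. For readers unfamiliar with it, a quick standalone illustration of what that metric does (outside any trainer):

import torchmetrics

# SumMetric accumulates every value passed to update() and, under DDP,
# reduces the running sum across ranks when compute() is called.
skipped = torchmetrics.SumMetric()
skipped.update(1)  # e.g. one skipped optimizer step
skipped.update(1)
print(skipped.compute())  # tensor(2.)
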
--- nshtrainer-0.24.0/src/nshtrainer/trainer/trainer.py
+++ nshtrainer-0.26.0/src/nshtrainer/trainer/trainer.py
@@ -280,13 +280,6 @@ class Trainer(LightningTrainer):
     if TYPE_CHECKING:
         callbacks: list[Callback]
 
-    def _nshtrainer_ckpt_link(self, ckpt_path: Path):
-        root_config = cast(BaseConfig, self._base_module.hparams)
-        ckpt_dir = root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-        return str(ckpt_path.absolute().relative_to(ckpt_dir))
-
     @override
     def __init__(
         self,
@@ -295,7 +288,6 @@ class Trainer(LightningTrainer):
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
         self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
-        self._nshtrainer_checkpoint_link_dict = dict[str, Path]()
 
         self._pre_init(config)
 
@@ -454,9 +446,6 @@ class Trainer(LightningTrainer):
                     _link_checkpoint(cached_path, filepath, metadata=False)
                 else:
                     shutil.copy(cached_path, filepath)
-                self._nshtrainer_checkpoint_link_dict[
-                    self._nshtrainer_ckpt_link(filepath)
-                ] = cached_path
             self.strategy.barrier("Trainer.save_checkpoint")
         else:
             super().save_checkpoint(filepath, weights_only, storage_options)
--- nshtrainer-0.24.0/src/nshtrainer/util/path.py
+++ nshtrainer-0.26.0/src/nshtrainer/util/path.py
@@ -1,3 +1,4 @@
+import hashlib
 import os
 from pathlib import Path
 from typing import TypeAlias
@@ -50,3 +51,20 @@ def find_symlinks(
         pass
 
     return symlinks
+
+
+def compute_file_checksum(file_path: Path) -> str:
+    """
+    Calculate the SHA256 checksum of a file.
+
+    Args:
+        file_path (Path): The path to the file.
+
+    Returns:
+        str: The hexadecimal representation of the file's SHA256 checksum.
+    """
+    sha256_hash = hashlib.sha256()
+    with file_path.open("rb") as f:
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+    return sha256_hash.hexdigest()
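
Together with the checksum field added to CheckpointMetadata, this helper is what lets HFHubCallback recognize byte-identical checkpoints and issue a cheap in-repo copy instead of a second upload. A hedged usage sketch (the checkpoints/ glob and the printed message are illustrative, not part of the package):

from pathlib import Path

from nshtrainer.util.path import compute_file_checksum

# Map each distinct file content to the first path that carried it,
# mirroring HFHubCallback._checksum_to_path_in_repo.
seen: dict[str, Path] = {}
for ckpt in sorted(Path("checkpoints").glob("**/*.ckpt")):
    digest = compute_file_checksum(ckpt)
    if (first := seen.get(digest)) is not None:
        print(f"{ckpt} duplicates {first}; an in-repo copy would suffice")
    else:
        seen[digest] = ckpt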