nshtrainer 0.17.1__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -112,8 +112,10 @@ def _write_checkpoint_metadata(
         metadata_path.write_text(metadata.model_dump_json(indent=4), encoding="utf-8")
     except Exception:
         log.exception(f"Failed to write metadata to {checkpoint_path}")
-    else:
-        log.debug(f"Checkpoint metadata written to {checkpoint_path}")
+        return None
+
+    log.debug(f"Checkpoint metadata written to {checkpoint_path}")
+    return checkpoint_path
 
 
 def _remove_checkpoint_metadata(checkpoint_path: Path):
nshtrainer/_hf_hub.py ADDED
@@ -0,0 +1,347 @@
+import logging
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, cast
+
+import nshconfig as C
+from lightning.pytorch import Callback
+from lightning.pytorch.trainer import Trainer
+from nshrunner._env import SNAPSHOT_DIR
+from typing_extensions import override
+
+from .callbacks.base import (
+    CallbackConfigBase,
+    CallbackMetadataConfig,
+    CallbackWithMetadata,
+)
+
+if TYPE_CHECKING:
+    from huggingface_hub import HfApi  # noqa: F401
+
+    from .model.base import BaseConfig
+    from .trainer.trainer import Trainer
+log = logging.getLogger(__name__)
+
+
+class HuggingFaceHubAutoCreateConfig(C.Config):
+    enabled: bool = True
+    """Enable automatic repository creation on the Hugging Face Hub."""
+
+    private: bool = True
+    """Whether to create the repository as private."""
+
+    namespace: str | None = None
+    """The namespace to create the repository in. If `None`, the repository will be created in the user's namespace."""
+
+    def __bool__(self):
+        return self.enabled
+
+
+class HuggingFaceHubConfig(CallbackConfigBase):
+    """Configuration options for Hugging Face Hub integration."""
+
+    enabled: bool = False
+    """Enable Hugging Face Hub integration."""
+
+    token: str | None = None
+    """Hugging Face Hub API token. If `None`, the token will be read from the current environment.
+    This needs to either be set using `huggingface-cli login` or by setting the `HUGGINGFACE_TOKEN`
+    environment variable."""
+
+    auto_create: HuggingFaceHubAutoCreateConfig = HuggingFaceHubAutoCreateConfig()
+    """Automatic repository creation configuration options."""
+
+    save_config: bool = True
+    """Whether to save the model configuration to the Hugging Face Hub."""
+
+    save_checkpoints: bool = True
+    """Whether to save checkpoints to the Hugging Face Hub."""
+
+    save_code: bool = True
+    """Whether to save code to the Hugging Face Hub.
+    This is only supported if `nshsnap` is installed and snapshotting is enabled."""
+
+    save_in_background: bool = True
+    """Whether to save to the Hugging Face Hub in the background.
+    This corresponds to setting `run_as_future=True` in the HFApi upload methods."""
+
+    def enable_(self):
+        self.enabled = True
+        return self
+
+    def disable_(self):
+        self.enabled = False
+        return self
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield CallbackWithMetadata(
+            HFHubCallback(self),
+            CallbackMetadataConfig(ignore_if_exists=True),
+        )
+
+
+def _api(token: str | None = None):
+    # Make sure that `huggingface_hub` is installed
+    try:
+        import huggingface_hub  # noqa: F401
+    except ImportError:
+        log.exception(
+            "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
+        )
+        return None
+
+    # Create and authenticate the API instance
+    try:
+        api = huggingface_hub.HfApi(token=token)
+
+        # Verify authentication
+        api.whoami()
+    except Exception as e:
+        log.exception(
+            f"Authentication failed for Hugging Face Hub: {str(e)}. "
+            "Please make sure you are logged in using `huggingface-cli login`, "
+            "by setting the HUGGING_FACE_HUB_TOKEN environment variable, "
+            "or by providing a valid token in the configuration."
+        )
+        return None
+
+    return api
+
+
+def _enabled_and_valid(
+    trainer: "Trainer",
+    config: HuggingFaceHubConfig,
+    *,
+    rank_zero_only: bool,
+):
+    # Make sure this is enabled and the config is valid
+    if not config:
+        return None
+
+    # If `rank_zero_only` and this is not rank 0, stop here.
+    if rank_zero_only and not trainer.is_global_zero:
+        return
+
+    # Make sure that `huggingface_hub` is installed
+    try:
+        import huggingface_hub  # noqa: F401
+    except ImportError:
+        log.exception(
+            "Could not import `huggingface_hub`. Please install it using `pip install huggingface_hub`."
+        )
+        return None
+
+    # Create and authenticate the API instance
+    if (api := getattr(trainer, "_hf_hub_api", None)) is None:
+        api = _api(config.token)
+        setattr(trainer, "_hf_hub_api", api)
+    return cast(huggingface_hub.HfApi, api)
+
+
+def _repo_name(api: "HfApi", root_config: "BaseConfig"):
+    username = None
+    if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
+        username = ac.namespace
+    elif (username := api.whoami().get("name", None)) is None:
+        raise ValueError("Could not get username from Hugging Face Hub.")
+
+    return f"{username}/{root_config.project}-{root_config.run_name}-{root_config.id}"
+
+
+def _init(*, trainer: "Trainer", root_config: "BaseConfig"):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(
+            trainer,
+            config,
+            rank_zero_only=True,
+        )
+    ) is None or not config.auto_create:
+        return
+
+    from huggingface_hub.utils import RepositoryNotFoundError
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Create the repository, if it doesn't exist
+    try:
+        # Check if the repository exists
+        api.repo_info(repo_id=repo_name, repo_type="model")
+        log.info(f"Repository '{repo_name}' already exists.")
+    except RepositoryNotFoundError:
+        # Repository doesn't exist, so create it
+        try:
+            api.create_repo(
+                repo_id=repo_name,
+                repo_type="model",
+                private=config.auto_create.private,
+                exist_ok=True,
+            )
+            log.info(f"Created new repository '{repo_name}'.")
+        except Exception as e:
+            log.exception(f"Failed to create repository '{repo_name}': {str(e)}")
+    except Exception as e:
+        log.exception(f"Error checking repository '{repo_name}': {str(e)}")
+
+    # Upload the config
+    _save_config(root_config, trainer=trainer)
+
+    # Upload the code
+    _save_code(repo_name, config=config, trainer=trainer)
+
+
+def _save_code(
+    repo_name: str,
+    *,
+    config: HuggingFaceHubConfig,
+    trainer: "Trainer",
+):
+    if (
+        api := _enabled_and_valid(
+            trainer,
+            config,
+            rank_zero_only=True,
+        )
+    ) is None or not config.save_code:
+        return
+
+    # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
+    # then upload all contents within the snapshot directory to the repository.
+    snapshot_dir = os.environ.get(SNAPSHOT_DIR)
+    if not snapshot_dir:
+        log.info("No snapshot directory found. Skipping upload.")
+        return
+
+    snapshot_path = Path(snapshot_dir)
+    if not snapshot_path.exists() or not snapshot_path.is_dir():
+        log.warning(
+            f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
+        )
+        return
+
+    try:
+        api.upload_folder(
+            folder_path=str(snapshot_path),
+            repo_id=repo_name,
+            repo_type="model",
+            path_in_repo="code",  # Prefix with "code" folder
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(
+            f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
+        )
+    except Exception as e:
+        log.exception(
+            f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder: {str(e)}"
+        )
+
+
+def _save_config(
+    root_config: "BaseConfig",
+    *,
+    trainer: "Trainer",
+):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(
+            trainer,
+            config,
+            rank_zero_only=True,
+        )
+    ) is None or not config.save_config:
+        return
+
+    # Convert the root config to a JSON string
+    # NOTE: This is a utf-8 string.
+    config_json = root_config.model_dump_json(indent=4)
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    # Upload the config file to the repository
+    try:
+        api.upload_file(
+            path_or_fileobj=config_json.encode("utf-8"),
+            path_in_repo="config.json",
+            repo_id=repo_name,
+            repo_type="model",
+            run_as_future=cast(Any, config.save_in_background),
+        )
+        log.info(f"Uploaded config.json to repository '{repo_name}'.")
+    except Exception as e:
+        log.exception(
+            f"Failed to upload config.json to repository '{repo_name}': {str(e)}"
+        )
+
+
+def _save_checkpoint_files(
+    trainer: "Trainer",
+    paths: list[Path],
+    *,
+    root_config: "BaseConfig",
+):
+    config = root_config.trainer.hf_hub
+    if (
+        api := _enabled_and_valid(trainer, config, rank_zero_only=True)
+    ) is None or not config.save_checkpoints:
+        return
+
+    # Resolve the checkpoint directory
+    checkpoint_dir = root_config.directory.resolve_subdirectory(
+        root_config.id, "checkpoint"
+    )
+
+    # Resolve the repository name
+    repo_name = _repo_name(api, root_config)
+
+    for p in paths:
+        try:
+            relative_path = p.relative_to(checkpoint_dir)
+        except ValueError:
+            log.warning(
+                f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
+            )
+            continue
+
+        # Prefix the path in repo with "checkpoints"
+        path_in_repo = Path("checkpoints") / relative_path
+
+        # Upload the checkpoint file to the repository
+        try:
+            api.upload_file(
+                path_or_fileobj=str(p.resolve().absolute()),
+                path_in_repo=str(path_in_repo),
+                repo_id=repo_name,
+                repo_type="model",
+                run_as_future=cast(Any, config.save_in_background),
+            )
+            log.info(
+                f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
+            )
+        except Exception as e:
+            log.exception(
+                f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}': {str(e)}"
+            )
+
+    log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
+
+
+class HFHubCallback(Callback):
+    def __init__(self, config: HuggingFaceHubConfig):
+        super().__init__()
+        self.config = config
+
+    @override
+    def setup(self, trainer, pl_module, stage):
+        root_config = cast("BaseConfig", pl_module.hparams)
+        _init(trainer=trainer, root_config=root_config)
+
+    @override
+    def teardown(self, trainer, pl_module, stage):
+        if hasattr(trainer, "_hf_hub_api"):
+            delattr(trainer, "_hf_hub_api")
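
A minimal usage sketch (not part of the diff) of the new integration: every attribute and method name below comes from the HuggingFaceHubConfig definition above; only `root_config`, a stand-in for your own BaseConfig instance, is hypothetical.

    # Hypothetical sketch: turn on the Hugging Face Hub integration added in 0.18.0.
    root_config.trainer.hf_hub.enable_()                    # sets enabled=True, returns the config
    root_config.trainer.hf_hub.auto_create.private = True   # auto-created repo will be private
    root_config.trainer.hf_hub.save_config = True           # upload config.json
    root_config.trainer.hf_hub.save_code = True             # upload the code snapshot under "code/"
    root_config.trainer.hf_hub.save_checkpoints = True      # upload checkpoints under "checkpoints/"

Because BaseConfig now also yields `trainer.hf_hub` among its callback configs (see the model/config.py hunk below), enabling the config is enough for `create_callbacks` to attach an HFHubCallback, which creates the repository and uploads the config and snapshot during `setup`.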
nshtrainer/model/__init__.py CHANGED
@@ -10,6 +10,7 @@ from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
 from .config import GradientClippingConfig as GradientClippingConfig
+from .config import HuggingFaceHubConfig as HuggingFaceHubConfig
 from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
nshtrainer/model/base.py CHANGED
@@ -192,6 +192,7 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         hparams = self.config_cls().model_validate(hparams)
         hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
         hparams = self.pre_init_update_hparams(hparams)
+
         super().__init__(hparams)
 
         self.save_hyperparameters(hparams)
nshtrainer/model/config.py CHANGED
@@ -33,6 +33,7 @@ from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import Self, TypedDict, TypeVar, override
 
 from .._checkpoint.loader import CheckpointLoadingConfig
+from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
@@ -819,6 +820,9 @@ class TrainerConfig(C.Config):
     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
     """Checkpoint saving configuration options."""
 
+    hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
+    """Hugging Face Hub configuration options."""
+
     logging: LoggingConfig = LoggingConfig()
     """Logging/experiment tracking (e.g., WandB) configuration options."""
 
@@ -1213,4 +1217,5 @@ class BaseConfig(C.Config):
         yield self.trainer.checkpoint_saving
         yield self.trainer.logging
         yield self.trainer.optimizer
+        yield self.trainer.hf_hub
         yield from self.trainer.callbacks
nshtrainer/trainer/trainer.py CHANGED
@@ -420,8 +420,16 @@ class Trainer(LightningTrainer):
         # Save the checkpoint metadata
         lm = self._base_module
         hparams = cast(BaseConfig, lm.hparams)
+        metadata_path = None
         if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            _write_checkpoint_metadata(self, lm, filepath)
+            metadata_path = _write_checkpoint_metadata(self, lm, filepath)
+
+        # If HF Hub is enabled, then we upload
+        if hparams.trainer.hf_hub:
+            from .._hf_hub import _save_checkpoint_files
+
+            files = [f for f in (filepath, metadata_path) if f is not None]
+            _save_checkpoint_files(self, files, root_config=hparams)
 
         return ret_val
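
Assuming default settings, the upload helpers referenced above lay out the Hub repository roughly as follows (a sketch, not tool output; actual file names depend on the run):

    <username or namespace>/<project>-<run_name>-<id>    # repo id built by _repo_name
    ├── config.json        # root config, uploaded by _save_config
    ├── code/...           # snapshot contents, uploaded by _save_code
    └── checkpoints/...    # checkpoint and metadata files, uploaded by _save_checkpoint_files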
nshtrainer-0.17.1.dist-info/METADATA → nshtrainer-0.18.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.17.1
+Version: 0.18.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -11,6 +11,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Provides-Extra: extra
 Requires-Dist: GitPython ; extra == "extra"
+Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
nshtrainer-0.17.1.dist-info/RECORD → nshtrainer-0.18.0.dist-info/RECORD
@@ -1,8 +1,9 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
 nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjXtwg,13837
-nshtrainer/_checkpoint/metadata.py,sha256=FGVYqqHp5rCETcPfaoSZmGIPapE4kdYJCKSutTRERQI,5147
+nshtrainer/_checkpoint/metadata.py,sha256=_9dBLJSCgi3H98-HJLgwVr8U7yHxbQA5VB9ZYMYjFj0,5181
 nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
+nshtrainer/_hf_hub.py,sha256=b1Na0-SyOM5xlJCH8cqjk0ggEVCPMI_z770c32JIQRY,10701
 nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
@@ -54,9 +55,9 @@ nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9b
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=BmqSbf6v6oyeilti4iEn_Tyrr1kRmcFcJekTb8NeglI,1315
-nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
-nshtrainer/model/config.py,sha256=D6Y-Y7GoMrpo7A2dmIqJsqc4X2IHwyl9OEHxO4uOc0g,42918
+nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
+nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
+nshtrainer/model/config.py,sha256=147uV7IukvuYE4G_ZuQNxVjnlog1BdCrAVbcj_sx9Vs,43104
 nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -76,7 +77,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=jIqiNrq1I0f5pQP7lHshtgjCAYfpoWPoqwS74LHU9iM,17148
+nshtrainer/trainer/trainer.py,sha256=xJBl8C-9SVT1ppmxTVwT1PIN8vZmE1erpKtKlsX2-8Y,17479
 nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
@@ -84,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.17.1.dist-info/METADATA,sha256=yFOvBuWFb0naLc6p82HVodctRPi1AAzWIb3wwa5Vq-I,885
-nshtrainer-0.17.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.17.1.dist-info/RECORD,,
+nshtrainer-0.18.0.dist-info/METADATA,sha256=uKcju9SCdP6M3h-GjX0OOcpd52_cNThmUPmMYUpBIk4,935
+nshtrainer-0.18.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.18.0.dist-info/RECORD,,