nshtrainer 0.26.1__py3-none-any.whl → 0.26.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -188,27 +188,19 @@ class HFHubCallback(NTCallbackBase):
188
188
 
189
189
  self.config = config
190
190
 
191
- self._repo_id = None
191
+ self._repo_id: str | None = None
192
192
  self._checksum_to_path_in_repo: dict[str, Path] = {}
193
193
 
194
194
  @override
195
195
  def setup(self, trainer, pl_module, stage):
196
- from .trainer.trainer import Trainer
197
-
198
- if not isinstance(trainer, Trainer):
199
- raise ValueError(
200
- f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
201
- )
202
-
203
196
  root_config = cast("BaseConfig", pl_module.hparams)
197
+ self._repo_id = _repo_name(self.api, root_config)
198
+
199
+ if not self.config or not trainer.is_global_zero:
200
+ return
204
201
 
205
202
  # Create the repository, if it doesn't exist
206
- self._repo_id = self.api.create_repo(
207
- repo_id=_repo_name(self.api, root_config),
208
- repo_type="model",
209
- private=self.config.auto_create.private,
210
- exist_ok=True,
211
- )
203
+ self._create_repo_if_not_exists()
212
204
 
213
205
  # Upload the config and code
214
206
  self._save_config(root_config)
@@ -216,17 +208,22 @@ class HFHubCallback(NTCallbackBase):
216
208
 
217
209
  @override
218
210
  def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
219
- root_config = cast("BaseConfig", pl_module.hparams)
220
-
221
211
  # If HF Hub is enabled, then we upload
222
- if self.config and trainer.is_global_zero:
223
- with self._with_error_handling("save checkpoints"):
224
- self._save_checkpoint(
225
- _Upload.from_local_path(ckpt_path, root_config),
226
- _Upload.from_local_path(metadata_path, root_config)
227
- if metadata_path is not None
228
- else None,
229
- )
212
+ if (
213
+ not self.config
214
+ or not self.config.save_checkpoints
215
+ or not trainer.is_global_zero
216
+ ):
217
+ return
218
+
219
+ with self._with_error_handling("save checkpoints"):
220
+ root_config = cast("BaseConfig", pl_module.hparams)
221
+ self._save_checkpoint(
222
+ _Upload.from_local_path(ckpt_path, root_config),
223
+ _Upload.from_local_path(metadata_path, root_config)
224
+ if metadata_path is not None
225
+ else None,
226
+ )
230
227
 
231
228
  @cached_property
232
229
  def api(self):
@@ -241,6 +238,33 @@ class HFHubCallback(NTCallbackBase):
241
238
  raise ValueError("Repository id has not been initialized.")
242
239
  return self._repo_id
243
240
 
241
+ def _create_repo_if_not_exists(self):
242
+ if not self.config or not self.config.auto_create:
243
+ return
244
+
245
+ # Create the repository, if it doesn't exist
246
+ with self._with_error_handling("create repository"):
247
+ from huggingface_hub.utils import RepositoryNotFoundError
248
+
249
+ try:
250
+ # Check if the repository exists
251
+ self.api.repo_info(repo_id=self.repo_id, repo_type="model")
252
+ log.info(f"Repository '{self.repo_id}' already exists.")
253
+ except RepositoryNotFoundError:
254
+ # Repository doesn't exist, so create it
255
+ try:
256
+ self.api.create_repo(
257
+ repo_id=self.repo_id,
258
+ repo_type="model",
259
+ private=self.config.auto_create.private,
260
+ exist_ok=True,
261
+ )
262
+ log.info(f"Created new repository '{self.repo_id}'.")
263
+ except Exception:
264
+ log.exception(f"Failed to create repository '{self.repo_id}'")
265
+ except Exception:
266
+ log.exception(f"Error checking repository '{self.repo_id}'")
267
+
244
268
  def _save_config(self, root_config: "BaseConfig"):
245
269
  with self._with_error_handling("upload config"):
246
270
  self.api.upload_file(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.26.1
3
+ Version: 0.26.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -4,7 +4,7 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
5
5
  nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
7
+ nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.26.1.dist-info/METADATA,sha256=tMMpyg1BTKec5d69ziW6XBxDXaI0gSK5tDMPCmj7VCQ,916
91
- nshtrainer-0.26.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.26.1.dist-info/RECORD,,
90
+ nshtrainer-0.26.2.dist-info/METADATA,sha256=d3vWjdB9FT6fbWJPkyBI-4M18ekg2WcJmJiKasExchM,916
91
+ nshtrainer-0.26.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.26.2.dist-info/RECORD,,