nshtrainer 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,7 @@ import numpy as np
11
11
  import torch
12
12
 
13
13
  from ..util._environment_info import EnvironmentConfig
14
- from ..util.path import get_relative_path
14
+ from ..util.path import compute_file_checksum, get_relative_path
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from ..model import BaseConfig, LightningModuleBase
@@ -28,6 +28,7 @@ class CheckpointMetadata(C.Config):
28
28
 
29
29
  checkpoint_path: Path
30
30
  checkpoint_filename: str
31
+ checkpoint_checksum: str
31
32
 
32
33
  run_id: str
33
34
  name: str
@@ -81,6 +82,7 @@ def _generate_checkpoint_metadata(
81
82
  # moving the checkpoint directory
82
83
  checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
83
84
  checkpoint_filename=checkpoint_path.name,
85
+ checkpoint_checksum=compute_file_checksum(checkpoint_path),
84
86
  run_id=config.id,
85
87
  name=config.run_name,
86
88
  project=config.project,
nshtrainer/_hf_hub.py CHANGED
@@ -1,7 +1,8 @@
1
- import io
1
+ import contextlib
2
2
  import logging
3
3
  import os
4
4
  import re
5
+ from dataclasses import dataclass
5
6
  from functools import cached_property
6
7
  from pathlib import Path
7
8
  from typing import TYPE_CHECKING, Any, cast
@@ -17,7 +18,6 @@ if TYPE_CHECKING:
17
18
  from huggingface_hub import HfApi # noqa: F401
18
19
 
19
20
  from .model.base import BaseConfig
20
- from .trainer.trainer import Trainer
21
21
 
22
22
 
23
23
  log = logging.getLogger(__name__)
@@ -109,24 +109,6 @@ def _api(token: str | None = None):
109
109
  return api
110
110
 
111
111
 
112
- def _enabled_and_valid(
113
- trainer: "Trainer",
114
- callback: "HFHubCallback",
115
- config: HuggingFaceHubConfig,
116
- *,
117
- rank_zero_only: bool,
118
- ):
119
- # Make sure this is enabled and the config is valid
120
- if not config:
121
- return None
122
-
123
- # If `rank_zero_only` and this is not rank 0, stop here.
124
- if rank_zero_only and not trainer.is_global_zero:
125
- return None
126
-
127
- return callback._hf_hub_api
128
-
129
-
130
112
  def _repo_name(api: "HfApi", root_config: "BaseConfig"):
131
113
  username = None
132
114
  if (ac := root_config.trainer.hf_hub.auto_create) and ac.namespace:
@@ -162,357 +144,210 @@ def _repo_name(api: "HfApi", root_config: "BaseConfig"):
162
144
  return f"{username}/{repo_name}"
163
145
 
164
146
 
165
- def _init(*, trainer: "Trainer", callback: "HFHubCallback", root_config: "BaseConfig"):
166
- config = root_config.trainer.hf_hub
167
- if (
168
- api := _enabled_and_valid(
169
- trainer,
170
- callback,
171
- config,
172
- rank_zero_only=True,
147
+ @dataclass
148
+ class _Upload:
149
+ local_path: Path
150
+ path_in_repo: Path
151
+
152
+ @classmethod
153
+ def from_local_path(
154
+ cls,
155
+ local_path: Path,
156
+ root_config: "BaseConfig",
157
+ ):
158
+ # Resolve the checkpoint directory
159
+ checkpoint_dir = root_config.directory.resolve_subdirectory(
160
+ root_config.id, "checkpoint"
173
161
  )
174
- ) is None or not config.auto_create:
175
- return
176
-
177
- from huggingface_hub.utils import RepositoryNotFoundError
178
-
179
- # Resolve the repository name
180
- repo_name = _repo_name(api, root_config)
181
162
 
182
- # Create the repository, if it doesn't exist
183
- try:
184
- # Check if the repository exists
185
- api.repo_info(repo_id=repo_name, repo_type="model")
186
- log.info(f"Repository '{repo_name}' already exists.")
187
- except RepositoryNotFoundError:
188
- # Repository doesn't exist, so create it
189
163
  try:
190
- api.create_repo(
191
- repo_id=repo_name,
192
- repo_type="model",
193
- private=config.auto_create.private,
194
- exist_ok=True,
164
+ relative_path = local_path.relative_to(checkpoint_dir)
165
+ except ValueError:
166
+ raise ValueError(
167
+ f"Checkpoint path {local_path} is not within the checkpoint directory {checkpoint_dir}."
195
168
  )
196
- log.info(f"Created new repository '{repo_name}'.")
197
- except Exception:
198
- log.exception(f"Failed to create repository '{repo_name}'")
199
- except Exception:
200
- log.exception(f"Error checking repository '{repo_name}'")
201
-
202
- # Upload the config
203
- _save_config(root_config, trainer=trainer, callback=callback)
204
-
205
- # Upload the code
206
- _save_code(repo_name, config=config, trainer=trainer, callback=callback)
207
-
208
-
209
- def _save_code(
210
- repo_name: str,
211
- *,
212
- config: HuggingFaceHubConfig,
213
- trainer: "Trainer",
214
- callback: "HFHubCallback",
215
- ):
216
- if (
217
- api := _enabled_and_valid(
218
- trainer,
219
- callback,
220
- config,
221
- rank_zero_only=True,
222
- )
223
- ) is None or not config.save_code:
224
- return
225
-
226
- # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
227
- # then upload all contents within the snapshot directory to the repository.
228
- snapshot_dir = os.environ.get(SNAPSHOT_DIR)
229
- if not snapshot_dir:
230
- log.info("No snapshot directory found. Skipping upload.")
231
- return
232
-
233
- snapshot_path = Path(snapshot_dir)
234
- if not snapshot_path.exists() or not snapshot_path.is_dir():
235
- log.warning(
236
- f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
237
- )
238
- return
239
169
 
240
- try:
241
- api.upload_folder(
242
- folder_path=str(snapshot_path),
243
- repo_id=repo_name,
244
- repo_type="model",
245
- path_in_repo="code", # Prefix with "code" folder
246
- run_as_future=cast(Any, config.save_in_background),
247
- )
248
- log.info(
249
- f"Uploaded snapshot contents to repository '{repo_name}' under 'code' folder."
250
- )
251
- except Exception:
252
- log.exception(
253
- f"Failed to upload snapshot contents to repository '{repo_name}' under 'code' folder."
254
- )
170
+ # Prefix the path in repo with "checkpoints"
171
+ path_in_repo = Path("checkpoints") / relative_path
255
172
 
173
+ return cls(local_path=local_path, path_in_repo=path_in_repo)
256
174
 
257
- def _save_config(
258
- root_config: "BaseConfig",
259
- *,
260
- trainer: "Trainer",
261
- callback: "HFHubCallback",
262
- ):
263
- config = root_config.trainer.hf_hub
264
- if (
265
- api := _enabled_and_valid(
266
- trainer,
267
- callback,
268
- config,
269
- rank_zero_only=True,
270
- )
271
- ) is None or not config.save_config:
272
- return
273
175
 
274
- # Convert the root config to a JSON string
275
- # NOTE: This is a utf-8 string.
276
- config_json = root_config.model_dump_json(indent=4)
176
+ class HFHubCallback(NTCallbackBase):
177
+ @contextlib.contextmanager
178
+ def _with_error_handling(self, opeartion: str):
179
+ try:
180
+ yield
181
+ except Exception:
182
+ log.exception(f"Failed to {opeartion}, repo_id={self._repo_id}")
183
+ else:
184
+ log.debug(f"Successfully {opeartion}, repo_id={self._repo_id}")
277
185
 
278
- # Resolve the repository name
279
- repo_name = _repo_name(api, root_config)
186
+ def __init__(self, config: HuggingFaceHubConfig):
187
+ super().__init__()
280
188
 
281
- # Upload the config file to the repository
282
- try:
283
- api.upload_file(
284
- path_or_fileobj=config_json.encode("utf-8"),
285
- path_in_repo="config.json",
286
- repo_id=repo_name,
287
- repo_type="model",
288
- run_as_future=cast(Any, config.save_in_background),
289
- )
290
- log.info(f"Uploaded config.json to repository '{repo_name}'.")
291
- except Exception:
292
- log.exception(f"Failed to upload config.json to repository '{repo_name}'.")
189
+ self.config = config
293
190
 
191
+ self._repo_id = None
192
+ self._checksum_to_path_in_repo: dict[str, Path] = {}
294
193
 
295
- def _is_link(p: Path, trainer: "Trainer"):
296
- if p.is_symlink():
297
- return True
194
+ @override
195
+ def setup(self, trainer, pl_module, stage):
196
+ from .trainer.trainer import Trainer
298
197
 
299
- try:
300
- link_path = trainer._nshtrainer_ckpt_link(p)
301
- return link_path in trainer._nshtrainer_checkpoint_link_dict
302
- except Exception:
303
- log.info(f"Failed to check if path {p} is a symlink.", exc_info=True)
198
+ if not isinstance(trainer, Trainer):
199
+ raise ValueError(
200
+ f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
201
+ )
304
202
 
305
- return False
203
+ root_config = cast("BaseConfig", pl_module.hparams)
306
204
 
205
+ # Create the repository, if it doesn't exist
206
+ self._repo_id = self.api.create_repo(
207
+ repo_id=_repo_name(self.api, root_config),
208
+ repo_type="model",
209
+ private=self.config.auto_create.private,
210
+ exist_ok=True,
211
+ )
307
212
 
308
- def _resolve_link(p: Path, trainer: "Trainer"):
309
- if p.is_symlink():
310
- return p.resolve()
213
+ # Upload the config and code
214
+ self._save_config(root_config)
215
+ self._save_code()
311
216
 
312
- try:
313
- link_path = trainer._nshtrainer_ckpt_link(p)
314
- if (
315
- resolved := trainer._nshtrainer_checkpoint_link_dict.get(link_path)
316
- ) is not None:
317
- return resolved
318
- except Exception:
319
- log.info(f"Failed to resolve symlink for path {p}.", exc_info=True)
320
-
321
- return None
322
-
323
-
324
- def _save_checkpoint_files(
325
- trainer: "Trainer",
326
- callback: "HFHubCallback",
327
- paths: list[Path],
328
- *,
329
- root_config: "BaseConfig",
330
- ):
331
- config = root_config.trainer.hf_hub
332
- if (
333
- api := _enabled_and_valid(trainer, callback, config, rank_zero_only=True)
334
- ) is None or not config.save_checkpoints:
335
- return
336
-
337
- # Resolve the checkpoint directory
338
- checkpoint_dir = root_config.directory.resolve_subdirectory(
339
- root_config.id, "checkpoint"
340
- )
341
-
342
- # Resolve the repository name
343
- repo_name = _repo_name(api, root_config)
344
-
345
- # Let's read all the files to memory right now,
346
- # in case they get used/removed by other processes.
347
- # Read all the files to memory
348
- file_contents: list[bytes | None] = []
349
- for p in paths:
350
- assert not _is_link(p, trainer=trainer), f"Path {p} is a symlink."
351
- assert p.is_file(), f"Path {p} is not a file."
352
- try:
353
- with open(p, "rb") as f:
354
- file_contents.append(f.read())
355
- except IOError:
356
- log.exception(f"Failed to read checkpoint file {p}.")
357
- file_contents.append(None)
358
-
359
- # Remove the paths that failed to read
360
- file_contents_and_paths = [
361
- (contents, p)
362
- for contents, p in zip(file_contents, paths)
363
- if contents is not None
364
- ]
365
-
366
- # Upload the checkpoint files to the repository
367
- for contents, p in file_contents_and_paths:
368
- try:
369
- relative_path = p.relative_to(checkpoint_dir)
370
- except ValueError:
371
- log.warning(
372
- f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
373
- )
374
- continue
217
+ @override
218
+ def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
219
+ root_config = cast("BaseConfig", pl_module.hparams)
375
220
 
376
- # Prefix the path in repo with "checkpoints"
377
- path_in_repo = Path("checkpoints") / relative_path
221
+ # If HF Hub is enabled, then we upload
222
+ if self.config and trainer.is_global_zero:
223
+ with self._with_error_handling("save checkpoints"):
224
+ self._save_checkpoint(
225
+ _Upload.from_local_path(ckpt_path, root_config),
226
+ _Upload.from_local_path(metadata_path, root_config)
227
+ if metadata_path is not None
228
+ else None,
229
+ )
378
230
 
379
- # Upload the checkpoint file to the repository
380
- try:
381
- api.upload_file(
382
- path_or_fileobj=io.BytesIO(contents),
383
- path_in_repo=str(path_in_repo),
384
- repo_id=repo_name,
231
+ @cached_property
232
+ def api(self):
233
+ # Create and authenticate the API instance
234
+ if (api := _api(self.config.token)) is None:
235
+ raise ValueError("Failed to create Hugging Face Hub API instance.")
236
+ return api
237
+
238
+ @property
239
+ def repo_id(self):
240
+ if self._repo_id is None:
241
+ raise ValueError("Repository id has not been initialized.")
242
+ return self._repo_id
243
+
244
+ def _save_config(self, root_config: "BaseConfig"):
245
+ with self._with_error_handling("upload config"):
246
+ self.api.upload_file(
247
+ path_or_fileobj=root_config.model_dump_json(indent=4).encode("utf-8"),
248
+ path_in_repo="config.json",
249
+ repo_id=self.repo_id,
385
250
  repo_type="model",
386
- run_as_future=cast(Any, config.save_in_background),
251
+ run_as_future=cast(Any, self.config.save_in_background),
387
252
  )
388
- log.info(
389
- f"Uploaded checkpoint file {relative_path} to repository '{repo_name}'."
390
- )
391
- except Exception:
392
- log.exception(
393
- f"Failed to upload checkpoint file {relative_path} to repository '{repo_name}'."
394
- )
395
-
396
- log.info(f"Completed uploading checkpoint files to repository '{repo_name}'.")
397
-
398
-
399
- def _save_checkpoint_symlinks(
400
- trainer: "Trainer",
401
- callback: "HFHubCallback",
402
- paths: list[Path],
403
- *,
404
- root_config: "BaseConfig",
405
- ):
406
- config = root_config.trainer.hf_hub
407
- if (
408
- api := _enabled_and_valid(trainer, callback, config, rank_zero_only=True)
409
- ) is None or not config.save_checkpoints:
410
- return
411
-
412
- # Resolve the checkpoint directory
413
- checkpoint_dir = root_config.directory.resolve_subdirectory(
414
- root_config.id, "checkpoint"
415
- )
416
-
417
- # Resolve the repository name
418
- repo_name = _repo_name(api, root_config)
419
-
420
- # Create a commit for copying the files
421
- from huggingface_hub.hf_api import CommitOperation, CommitOperationCopy
422
253
 
423
- commits: list[CommitOperation] = []
424
- for p in paths:
425
- assert _is_link(p, trainer=trainer), f"Path {p} is not a symlink."
254
+ def _save_code(self):
255
+ # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
256
+ # then upload all contents within the snapshot directory to the repository.
257
+ if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
258
+ log.debug("No snapshot directory found. Skipping upload.")
259
+ return
260
+
261
+ with self._with_error_handling("save code"):
262
+ snapshot_dir = Path(snapshot_dir)
263
+ if not snapshot_dir.exists() or not snapshot_dir.is_dir():
264
+ log.warning(
265
+ f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
266
+ )
267
+ return
426
268
 
427
- try:
428
- dest_relative_path = p.relative_to(checkpoint_dir)
429
- except ValueError:
430
- log.warning(
431
- f"Checkpoint path {p} is not within the checkpoint directory {checkpoint_dir}."
269
+ self.api.upload_folder(
270
+ folder_path=str(snapshot_dir),
271
+ repo_id=self.repo_id,
272
+ repo_type="model",
273
+ path_in_repo="code", # Prefix with "code" folder
274
+ run_as_future=cast(Any, self.config.save_in_background),
432
275
  )
433
- continue
434
276
 
435
- if (p_resolved := _resolve_link(p, trainer=trainer)) is None:
436
- log.warning(f"Failed to resolve symlink for path {p}.")
437
- continue
438
-
439
- try:
440
- source_relative_path = p_resolved.relative_to(checkpoint_dir)
441
- except ValueError:
442
- log.warning(
443
- f"Checkpoint symlink target {p_resolved} is not within the checkpoint directory {checkpoint_dir}."
277
+ def _save_file(self, p: _Upload):
278
+ with self._with_error_handling("save file"):
279
+ # Upload the checkpoint files to the repository
280
+ self.api.upload_file(
281
+ path_or_fileobj=p.local_path,
282
+ path_in_repo=str(p.path_in_repo),
283
+ repo_id=self.repo_id,
284
+ repo_type="model",
285
+ run_as_future=cast(Any, self.config.save_in_background),
444
286
  )
445
- continue
446
287
 
447
- # Prefix the path in repo with "checkpoints"
448
- dest_path_in_repo = Path("checkpoints") / dest_relative_path
449
- source_path_in_repo = Path("checkpoints") / source_relative_path
450
-
451
- # Create and append a CommitOperationCopy for copying the symlink
452
- copy_op = CommitOperationCopy(
453
- src_path_in_repo=str(source_path_in_repo),
454
- path_in_repo=str(dest_path_in_repo),
455
- )
456
- commits.append(copy_op)
457
-
458
- log.info(f"Creating a commit with the following operations: {commits}")
288
+ def _copy_file(self, source_path_in_repo: Path, dest_path_in_repo: Path):
289
+ # Create a commit for copying the files
290
+ from huggingface_hub.hf_api import CommitOperationCopy
459
291
 
460
- try:
461
- api.create_commit(
462
- repo_id=repo_name,
463
- repo_type="model",
464
- commit_message="Copy checkpoint symlinks",
465
- operations=commits,
466
- run_as_future=cast(Any, config.save_in_background),
467
- )
468
- log.info(
469
- f"Created commit to copy checkpoint symlinks to repository '{repo_name}'."
470
- )
471
- except Exception:
472
- log.exception(
473
- f"Failed to create commit to copy checkpoint symlinks to repository '{repo_name}'"
474
- )
292
+ with self._with_error_handling("copy file"):
293
+ copy_op = CommitOperationCopy(
294
+ src_path_in_repo=str(source_path_in_repo),
295
+ path_in_repo=str(dest_path_in_repo),
296
+ )
475
297
 
476
- log.info(f"Completed copying checkpoint symlinks to repository '{repo_name}'.")
298
+ self.api.create_commit(
299
+ repo_id=self.repo_id,
300
+ repo_type="model",
301
+ commit_message="Copy checkpoint file",
302
+ operations=[copy_op],
303
+ run_as_future=cast(Any, self.config.save_in_background),
304
+ )
477
305
 
306
+ def _save_checkpoint(self, path: _Upload, metadata_path: _Upload | None):
307
+ if not self.config.save_checkpoints:
308
+ return
478
309
 
479
- class HFHubCallback(NTCallbackBase):
480
- def __init__(self, config: HuggingFaceHubConfig):
481
- super().__init__()
482
- self.config = config
310
+ # If no metadata, just save regularly.
311
+ if metadata_path is None:
312
+ self._save_file(path)
313
+ return
483
314
 
484
- @override
485
- def setup(self, trainer, pl_module, stage):
486
- from .trainer.trainer import Trainer
315
+ # Otherwise, let's check to see if we've already uploaded the metadata.
316
+ # If so, we can just copy the checkpoint file.
317
+ from ._checkpoint.metadata import CheckpointMetadata
487
318
 
488
- if not isinstance(trainer, Trainer):
489
- raise ValueError(
490
- f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
319
+ metadata = CheckpointMetadata.from_file(metadata_path.local_path)
320
+ if (
321
+ existing_ckpt_path := self._checksum_to_path_in_repo.get(
322
+ metadata.checkpoint_checksum
323
+ )
324
+ ) is not None:
325
+ self._copy_file(existing_ckpt_path, path.path_in_repo)
326
+ else:
327
+ # Otherwise, we save the checkpoint & keep the checksum so we don't
328
+ # re-upload the same file again.
329
+ self._save_file(path)
330
+ self._checksum_to_path_in_repo[metadata.checkpoint_checksum] = (
331
+ path.path_in_repo
491
332
  )
492
333
 
493
- root_config = cast("BaseConfig", pl_module.hparams)
494
- _init(trainer=trainer, callback=self, root_config=root_config)
334
+ # Save the metadata file
335
+ # NOTE: This file is fairly small, so we can just upload it directly.
336
+ # No need to copy.
337
+ self._save_file(metadata_path)
495
338
 
496
339
  @override
497
- def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
498
- root_config = cast("BaseConfig", pl_module.hparams)
340
+ def state_dict(self):
341
+ return {
342
+ "repo_id": self._repo_id,
343
+ "checksum_to_path_in_repo": {
344
+ k: str(v) for k, v in self._checksum_to_path_in_repo.items()
345
+ },
346
+ }
499
347
 
500
- # If HF Hub is enabled, then we upload
501
- if root_config.trainer.hf_hub and trainer.is_global_zero:
502
- # Upload the regular files first, then the symlinks
503
- all_paths = [p for p in (ckpt_path, metadata_path) if p is not None]
504
- if regular_paths := [
505
- p for p in all_paths if not _is_link(p, trainer=trainer)
506
- ]:
507
- _save_checkpoint_files(
508
- trainer, self, regular_paths, root_config=root_config
509
- )
510
- if symlink_paths := [p for p in all_paths if _is_link(p, trainer=trainer)]:
511
- _save_checkpoint_symlinks(
512
- trainer, self, symlink_paths, root_config=root_config
513
- )
514
-
515
- @cached_property
516
- def _hf_hub_api(self):
517
- # Create and authenticate the API instance
518
- return _api(self.config.token)
348
+ @override
349
+ def load_state_dict(self, state_dict):
350
+ self._repo_id = state_dict["repo_id"]
351
+ self._checksum_to_path_in_repo = {
352
+ k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
353
+ }
@@ -70,5 +70,5 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
70
70
 
71
71
  # Events
72
72
  @override
73
- def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
73
+ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
74
74
  self.save_checkpoints(trainer)
@@ -39,5 +39,5 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
39
39
  return True
40
40
 
41
41
  @override
42
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
42
+ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
43
43
  self.save_checkpoints(trainer)
@@ -1,8 +1,8 @@
1
- import importlib.util
2
1
  import logging
3
2
  from typing import Any, Literal, Protocol, runtime_checkable
4
3
 
5
4
  import torch
5
+ import torchmetrics
6
6
  from lightning.pytorch import Callback, LightningModule, Trainer
7
7
  from torch.optim import Optimizer
8
8
  from typing_extensions import override
@@ -20,19 +20,12 @@ class HasGradSkippedSteps(Protocol):
20
20
 
21
21
  class GradientSkipping(Callback):
22
22
  def __init__(self, config: "GradientSkippingConfig"):
23
- if importlib.util.find_spec("torchmetrics") is not None:
24
- raise ImportError(
25
- "To use the GradientSkipping callback, please install torchmetrics: pip install torchmetrics"
26
- )
27
-
28
23
  super().__init__()
29
24
  self.config = config
30
25
 
31
26
  @override
32
27
  def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
33
28
  if not isinstance(pl_module, HasGradSkippedSteps):
34
- import torchmetrics # type: ignore
35
-
36
29
  pl_module.grad_skipped_steps = torchmetrics.SumMetric()
37
30
 
38
31
  @override
@@ -280,13 +280,6 @@ class Trainer(LightningTrainer):
280
280
  if TYPE_CHECKING:
281
281
  callbacks: list[Callback]
282
282
 
283
- def _nshtrainer_ckpt_link(self, ckpt_path: Path):
284
- root_config = cast(BaseConfig, self._base_module.hparams)
285
- ckpt_dir = root_config.directory.resolve_subdirectory(
286
- root_config.id, "checkpoint"
287
- )
288
- return str(ckpt_path.absolute().relative_to(ckpt_dir))
289
-
290
283
  @override
291
284
  def __init__(
292
285
  self,
@@ -295,7 +288,6 @@ class Trainer(LightningTrainer):
295
288
  **kwargs: Unpack[LightningTrainerKwargs],
296
289
  ):
297
290
  self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
298
- self._nshtrainer_checkpoint_link_dict = dict[str, Path]()
299
291
 
300
292
  self._pre_init(config)
301
293
 
@@ -454,9 +446,6 @@ class Trainer(LightningTrainer):
454
446
  _link_checkpoint(cached_path, filepath, metadata=False)
455
447
  else:
456
448
  shutil.copy(cached_path, filepath)
457
- self._nshtrainer_checkpoint_link_dict[
458
- self._nshtrainer_ckpt_link(filepath)
459
- ] = cached_path
460
449
  self.strategy.barrier("Trainer.save_checkpoint")
461
450
  else:
462
451
  super().save_checkpoint(filepath, weights_only, storage_options)
nshtrainer/util/path.py CHANGED
@@ -1,3 +1,4 @@
1
+ import hashlib
1
2
  import os
2
3
  from pathlib import Path
3
4
  from typing import TypeAlias
@@ -50,3 +51,20 @@ def find_symlinks(
50
51
  pass
51
52
 
52
53
  return symlinks
54
+
55
+
56
+ def compute_file_checksum(file_path: Path) -> str:
57
+ """
58
+ Calculate the SHA256 checksum of a file.
59
+
60
+ Args:
61
+ file_path (Path): The path to the file.
62
+
63
+ Returns:
64
+ str: The hexadecimal representation of the file's SHA256 checksum.
65
+ """
66
+ sha256_hash = hashlib.sha256()
67
+ with file_path.open("rb") as f:
68
+ for byte_block in iter(lambda: f.read(4096), b""):
69
+ sha256_hash.update(byte_block)
70
+ return sha256_hash.hexdigest()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.24.0
3
+ Version: 0.26.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -22,7 +22,7 @@ Requires-Dist: psutil
22
22
  Requires-Dist: pytorch-lightning
23
23
  Requires-Dist: tensorboard ; extra == "extra"
24
24
  Requires-Dist: torch
25
- Requires-Dist: torchmetrics ; extra == "extra"
25
+ Requires-Dist: torchmetrics
26
26
  Requires-Dist: typing-extensions
27
27
  Requires-Dist: wandb ; extra == "extra"
28
28
  Requires-Dist: wrapt ; extra == "extra"
@@ -1,23 +1,23 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
2
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
3
3
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
4
- nshtrainer/_checkpoint/metadata.py,sha256=E4tfiGzhnn65X95P0Y6K2d_YfPWqvHZoF0FF1-smEJc,5221
4
+ nshtrainer/_checkpoint/metadata.py,sha256=BpxC3VGrgSDHvZpc40A2icjBEDvRNGEsTis9YkFY8Kc,5341
5
5
  nshtrainer/_checkpoint/saver.py,sha256=fvRKGI5aeXtsHBOIO4cwGe__wmO-6DiD0-744VASYA4,1500
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=Ac4y7jmuAMEQOJPJgoYmiaIGlZvgyUcqpipb6fPuHSE,16587
7
+ nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
13
  nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
14
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
15
- nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
14
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
17
17
  nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
18
18
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
19
19
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
20
- nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
20
+ nshtrainer/callbacks/gradient_skipping.py,sha256=EBNkANDnD3BTszWjnG-jwY8FEj-iRqhE3e1x5LQF6M8,3393
21
21
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
22
22
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
23
23
  nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=vswxAhyLqTL99kRJvU4Q3uEyQT80eM3mN74yMhsyn_I,19905
81
+ nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
85
- nshtrainer/util/path.py,sha256=WbPWXpu5LIDocQihQC3-72qxN1sa6-d1kPOmKDR-NC8,1520
85
+ nshtrainer/util/path.py,sha256=RUkIOrlj9b8zPPXE3JLhdihBNitJSUWw1whZ33u-2Yk,2005
86
86
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.24.0.dist-info/METADATA,sha256=sYEAgF1cJ2a3Mjp137ienL4cvHswQHH76lm2GTj-Nb8,935
91
- nshtrainer-0.24.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.24.0.dist-info/RECORD,,
90
+ nshtrainer-0.26.0.dist-info/METADATA,sha256=YBlbpalQ3BX8UBF_5SHk_F7v9Nq3JMsqVf6MoqH8KzU,916
91
+ nshtrainer-0.26.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.26.0.dist-info/RECORD,,