nshtrainer 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,7 +10,7 @@ from lightning.pytorch.trainer.states import TrainerFn
10
10
  from typing_extensions import assert_never
11
11
 
12
12
  from ..metrics._config import MetricConfig
13
- from .metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
13
+ from .metadata import CheckpointMetadata
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  from ..model.config import BaseConfig
@@ -263,13 +263,13 @@ def _checkpoint_candidates(
263
263
 
264
264
  # Load all checkpoints in the directory.
265
265
  # We can do this by looking for metadata files.
266
- for path in ckpt_dir.glob(f"*{METADATA_PATH_SUFFIX}"):
266
+ for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
267
267
  if (meta := _load_ckpt_meta(path, root_config)) is not None:
268
268
  yield meta
269
269
 
270
270
  # If we have a pre-empted checkpoint, load it
271
271
  if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
272
- hpc_meta_path = Path(hpc_path).with_suffix(METADATA_PATH_SUFFIX)
272
+ hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
273
273
  if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
274
274
  yield meta
275
275
 
@@ -279,7 +279,9 @@ def _additional_candidates(
279
279
  ):
280
280
  for path in additional_candidates:
281
281
  if (
282
- meta := _load_ckpt_meta(path.with_suffix(METADATA_PATH_SUFFIX), root_config)
282
+ meta := _load_ckpt_meta(
283
+ path.with_suffix(CheckpointMetadata.PATH_SUFFIX), root_config
284
+ )
283
285
  ) is None:
284
286
  continue
285
287
  yield meta
@@ -310,7 +312,7 @@ def _resolve_checkpoint(
310
312
  match strategy:
311
313
  case UserProvidedPathCheckpointStrategyConfig():
312
314
  meta = _load_ckpt_meta(
313
- strategy.path.with_suffix(METADATA_PATH_SUFFIX),
315
+ strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
314
316
  root_config,
315
317
  on_error=strategy.on_error,
316
318
  )
@@ -40,7 +40,7 @@ class CheckpointMetadata(C.Config):
40
40
  metrics: dict[str, Any]
41
41
  environment: EnvironmentConfig
42
42
 
43
- hparams: dict[str, Any] | None
43
+ hparams: Any
44
44
 
45
45
  @classmethod
46
46
  def from_file(cls, path: Path):
@@ -48,9 +48,7 @@ class CheckpointMetadata(C.Config):
48
48
 
49
49
  @classmethod
50
50
  def from_ckpt_path(cls, checkpoint_path: Path):
51
- if not (
52
- metadata_path := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
53
- ).exists():
51
+ if not (metadata_path := checkpoint_path.with_suffix(cls.PATH_SUFFIX)).exists():
54
52
  raise FileNotFoundError(
55
53
  f"Metadata file not found for checkpoint: {checkpoint_path}"
56
54
  )
@@ -94,7 +92,7 @@ def _generate_checkpoint_metadata(
94
92
  training_time=training_time,
95
93
  metrics=metrics,
96
94
  environment=config.environment,
97
- hparams=config.model_dump(mode="json"),
95
+ hparams=config.model_dump(),
98
96
  )
99
97
 
100
98
 
@@ -104,26 +102,26 @@ def _write_checkpoint_metadata(
104
102
  checkpoint_path: Path,
105
103
  ):
106
104
  config = cast("BaseConfig", model.config)
107
- metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
105
+ metadata_path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
108
106
  metadata = _generate_checkpoint_metadata(
109
107
  config, trainer, checkpoint_path, metadata_path
110
108
  )
111
109
 
112
110
  # Write the metadata to the checkpoint directory
113
111
  try:
114
- metadata_path.write_text(metadata.model_dump_json(indent=4))
115
- except Exception as e:
116
- log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
112
+ metadata_path.write_text(metadata.model_dump_json(indent=4), encoding="utf-8")
113
+ except Exception:
114
+ log.exception(f"Failed to write metadata to {checkpoint_path}")
117
115
  else:
118
116
  log.debug(f"Checkpoint metadata written to {checkpoint_path}")
119
117
 
120
118
 
121
119
  def _remove_checkpoint_metadata(checkpoint_path: Path):
122
- path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
120
+ path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
123
121
  try:
124
122
  path.unlink(missing_ok=True)
125
- except Exception as e:
126
- log.warning(f"Failed to remove {path}: {e}")
123
+ except Exception:
124
+ log.exception(f"Failed to remove {path}")
127
125
  else:
128
126
  log.debug(f"Removed {path}")
129
127
 
@@ -133,8 +131,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
133
131
  _remove_checkpoint_metadata(linked_checkpoint_path)
134
132
 
135
133
  # Link the metadata files to the new checkpoint
136
- path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
137
- linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
134
+ path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
135
+ linked_path = linked_checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
138
136
  try:
139
137
  try:
140
138
  # linked_path.symlink_to(path)
@@ -146,8 +144,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
146
144
  # on Windows, special permissions are required to create symbolic links as a regular user
147
145
  # fall back to copying the file
148
146
  shutil.copy(path, linked_path)
149
- except Exception as e:
150
- log.warning(f"Failed to link {path} to {linked_path}: {e}")
147
+ except Exception:
148
+ log.exception(f"Failed to link {path} to {linked_path}")
151
149
  else:
152
150
  log.debug(f"Linked {path} to {linked_path}")
153
151
 
@@ -102,6 +102,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
102
102
  metas = [
103
103
  CheckpointMetadata.from_file(p)
104
104
  for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
105
+ if p.is_file() and not p.is_symlink()
105
106
  ]
106
107
 
107
108
  # Sort by the topk sort key
@@ -419,7 +419,8 @@ class Trainer(LightningTrainer):
419
419
 
420
420
  # Save the checkpoint metadata
421
421
  lm = self._base_module
422
- if lm.config.trainer.save_checkpoint_metadata and self.is_global_zero:
422
+ hparams = cast(BaseConfig, lm.hparams)
423
+ if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
423
424
  # Generate the metadata and write to disk
424
425
  _write_checkpoint_metadata(self, lm, filepath)
425
426
 
@@ -429,12 +429,12 @@ class EnvironmentPackageConfig(C.Config):
429
429
  version=clean_version,
430
430
  path=Path(str(f)) if (f := dist.locate_file("")) else None,
431
431
  summary=metadata["Summary"] if "Summary" in metadata else None,
432
- author=metadata["Author"] if "Summary" in metadata else None,
433
- license=metadata["License"] if "Summary" in metadata else None,
432
+ author=metadata["Author"] if "Author" in metadata else None,
433
+ license=metadata["License"] if "License" in metadata else None,
434
434
  requires=requires,
435
435
  )
436
- except Exception as e:
437
- log.warning(f"Error processing package {dist.name}: {str(e)}")
436
+ except Exception:
437
+ log.exception(f"Error processing package {dist.name}")
438
438
 
439
439
  except ImportError:
440
440
  log.warning(
@@ -672,8 +672,8 @@ class GitRepositoryConfig(C.Config):
672
672
  draft.is_dirty = repo.is_dirty()
673
673
  except git.InvalidGitRepositoryError:
674
674
  draft.is_git_repo = False
675
- except Exception as e:
676
- log.warning(f"Failed to get Git repository information: {e}")
675
+ except Exception:
676
+ log.exception("Failed to get Git repository information")
677
677
  draft.is_git_repo = None
678
678
 
679
679
  return draft.finalize()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.16.0
3
+ Version: 0.17.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
- nshtrainer/_checkpoint/loader.py,sha256=DSaNR8194kWon4O1svslNsCcN_8vlyLbF0LNCPfUpzI,13789
3
- nshtrainer/_checkpoint/metadata.py,sha256=onmetLp5eKbA86abq1PTkwAOO7bWj7Pa1EGUjl2TEjQ,5153
2
+ nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjXtwg,13837
3
+ nshtrainer/_checkpoint/metadata.py,sha256=ilq3Wz9QmILxwpy5qYm7kphhtqPUYAOh8D-OXc5SqSc,5131
4
4
  nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
5
5
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
6
6
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
@@ -8,7 +8,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
8
8
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
9
9
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
10
10
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
11
- nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt3w4KdzqrzLs,6094
11
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=YT_V-oihO9iB4ETl46CGYTCQjIYl-CpV7TMViTn07Lk,6144
12
12
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
13
13
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
14
14
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -76,14 +76,14 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
76
76
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
77
77
  nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
78
78
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
79
- nshtrainer/trainer/trainer.py,sha256=M97phnALfG18VxkMLoDr5AKFf4UaPBdc6S2BghdBtas,17103
80
- nshtrainer/util/_environment_info.py,sha256=Nmhls6u5rtMWbeDLGCjEk58efUutc_ONSUg3fs59TSI,24210
79
+ nshtrainer/trainer/trainer.py,sha256=jIqiNrq1I0f5pQP7lHshtgjCAYfpoWPoqwS74LHU9iM,17148
80
+ nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
81
81
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
82
82
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
83
83
  nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
84
84
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
85
85
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
86
86
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
87
- nshtrainer-0.16.0.dist-info/METADATA,sha256=O8wSotKFVUpCPLCLibdBmC6kUICBBcVMXsThhwuMmKI,885
88
- nshtrainer-0.16.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
89
- nshtrainer-0.16.0.dist-info/RECORD,,
87
+ nshtrainer-0.17.0.dist-info/METADATA,sha256=6qrwuGiaXKIaFBpYwvuqBjyDhkLX4bNjU1ojJeTLsQE,885
88
+ nshtrainer-0.17.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
89
+ nshtrainer-0.17.0.dist-info/RECORD,,