nshtrainer 1.3.3__py3-none-any.whl → 1.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import contextlib
4
4
  import logging
5
- import os
6
5
  import re
7
6
  from dataclasses import dataclass
8
7
  from functools import cached_property
@@ -10,7 +9,6 @@ from pathlib import Path
10
9
  from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
11
10
 
12
11
  import nshconfig as C
13
- from nshrunner._env import SNAPSHOT_DIR
14
12
  from typing_extensions import assert_never, override
15
13
 
16
14
  from ._callback import NTCallbackBase
@@ -19,6 +17,7 @@ from .callbacks.base import (
19
17
  CallbackMetadataConfig,
20
18
  callback_registry,
21
19
  )
20
+ from .util.code_upload import get_code_dir
22
21
 
23
22
  if TYPE_CHECKING:
24
23
  from huggingface_hub import HfApi # noqa: F401
@@ -319,20 +318,13 @@ class HFHubCallback(NTCallbackBase):
319
318
  def _save_code(self):
320
319
  # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
321
320
  # then upload all contents within the snapshot directory to the repository.
322
- if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
321
+ if (snapshot_dir := get_code_dir()) is None:
323
322
  log.debug("No snapshot directory found. Skipping upload.")
324
323
  return
325
324
 
326
325
  with self._with_error_handling("save code"):
327
- snapshot_dir = Path(snapshot_dir)
328
- if not snapshot_dir.exists() or not snapshot_dir.is_dir():
329
- log.warning(
330
- f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
331
- )
332
- return
333
-
334
326
  self.api.upload_folder(
335
- folder_path=str(snapshot_dir),
327
+ folder_path=str(snapshot_dir.absolute()),
336
328
  repo_id=self.repo_id,
337
329
  repo_type="model",
338
330
  path_in_repo="code", # Prefix with "code" folder
@@ -43,6 +43,13 @@ class EarlyStoppingCallbackConfig(CallbackConfigBase):
43
43
  the training will be stopped.
44
44
  """
45
45
 
46
+ skip_first_n_epochs: int = 0
47
+ """
48
+ Number of initial epochs to skip before starting to monitor for early stopping.
49
+ This helps avoid false early stopping when the model might temporarily perform worse
50
+ during early training phases.
51
+ """
52
+
46
53
  strict: bool = True
47
54
  """
48
55
  Whether to enforce that the monitored quantity must improve by at least `min_delta`
@@ -94,6 +101,16 @@ class EarlyStoppingCallback(_EarlyStopping):
94
101
  if getattr(trainer, "fast_dev_run", False):
95
102
  return
96
103
 
104
+ # Skip early stopping check for the first n epochs
105
+ if trainer.current_epoch < self.config.skip_first_n_epochs:
106
+ if self.verbose and trainer.current_epoch == 0:
107
+ self._log_info(
108
+ trainer,
109
+ f"Early stopping checks are disabled for the first {self.config.skip_first_n_epochs} epochs",
110
+ self.log_rank_zero_only,
111
+ )
112
+ return
113
+
97
114
  should_stop, reason = False, None
98
115
 
99
116
  if not should_stop:
@@ -1,16 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
- from pathlib import Path
6
4
  from typing import Literal, cast
7
5
 
8
6
  from lightning.pytorch import LightningModule, Trainer
9
7
  from lightning.pytorch.callbacks.callback import Callback
10
8
  from lightning.pytorch.loggers import WandbLogger
11
- from nshrunner._env import SNAPSHOT_DIR
12
9
  from typing_extensions import final, override
13
10
 
11
+ from ..util.code_upload import get_code_dir
14
12
  from .base import CallbackConfigBase, callback_registry
15
13
 
16
14
  log = logging.getLogger(__name__)
@@ -62,22 +60,12 @@ class WandbUploadCodeCallback(Callback):
62
60
  log.warning("Wandb logger not found. Skipping code upload.")
63
61
  return
64
62
 
65
- from wandb.wandb_run import Run
66
-
67
- run = cast(Run, logger.experiment)
68
-
69
- # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
70
- # then upload all contents within the snapshot directory to the repository.
71
- if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
72
- log.debug("No snapshot directory found. Skipping upload.")
63
+ if (snapshot_dir := get_code_dir()) is None:
64
+ log.info("No nshrunner snapshot found. Skipping code upload.")
73
65
  return
74
66
 
75
- snapshot_dir = Path(snapshot_dir)
76
- if not snapshot_dir.exists() or not snapshot_dir.is_dir():
77
- log.warning(
78
- f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
79
- )
80
- return
67
+ from wandb.wandb_run import Run
81
68
 
69
+ run = cast(Run, logger.experiment)
82
70
  log.info(f"Uploading code from snapshot directory '{snapshot_dir}'")
83
71
  run.log_code(str(snapshot_dir.absolute()))
@@ -14,7 +14,6 @@ from pathlib import Path
14
14
  from types import FrameType
15
15
  from typing import Any
16
16
 
17
- import nshrunner as nr
18
17
  import torch.utils.data
19
18
  from lightning.fabric.plugins.environments.lsf import LSFEnvironment
20
19
  from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
@@ -34,6 +33,12 @@ _IS_WINDOWS = platform.system() == "Windows"
34
33
 
35
34
 
36
35
  def _resolve_requeue_signals():
36
+ try:
37
+ import nshrunner as nr
38
+ except ImportError:
39
+ log.debug("nshrunner not found. Skipping signal requeueing.")
40
+ return None
41
+
37
42
  if (session := nr.Session.from_current_session()) is None:
38
43
  return None
39
44
 
@@ -52,9 +57,9 @@ class _SignalConnector(_LightningSignalConnector):
52
57
 
53
58
  signals_set = set(signals)
54
59
  valid_signals: set[signal.Signals] = signal.valid_signals()
55
- assert signals_set.issubset(
56
- valid_signals
57
- ), f"Invalid signal(s) found: {signals_set - valid_signals}"
60
+ assert signals_set.issubset(valid_signals), (
61
+ f"Invalid signal(s) found: {signals_set - valid_signals}"
62
+ )
58
63
  return signals
59
64
 
60
65
  def _compose_and_register(
@@ -241,9 +246,9 @@ class _SignalConnector(_LightningSignalConnector):
241
246
  "Writing requeue script to exit script directory."
242
247
  )
243
248
  exit_script_dir = Path(exit_script_dir)
244
- assert (
245
- exit_script_dir.is_dir()
246
- ), f"Exit script directory {exit_script_dir} does not exist"
249
+ assert exit_script_dir.is_dir(), (
250
+ f"Exit script directory {exit_script_dir} does not exist"
251
+ )
247
252
 
248
253
  exit_script_path = exit_script_dir / f"requeue_{job_id}.sh"
249
254
  log.info(f"Writing requeue script to {exit_script_path}")
@@ -356,12 +356,20 @@ class EnvironmentSnapshotConfig(C.Config):
356
356
 
357
357
  @classmethod
358
358
  def from_current_environment(cls):
359
- draft = cls.draft()
360
- if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
361
- draft.snapshot_dir = Path(snapshot_dir)
362
- if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
363
- draft.modules = modules.split(",")
364
- return draft.finalize()
359
+ try:
360
+ import nshrunner as nr
361
+
362
+ if (session := nr.Session.from_current_session()) is None:
363
+ log.warning("No active session found, skipping snapshot information")
364
+ return cls.empty()
365
+
366
+ draft = cls.draft()
367
+ draft.snapshot_dir = session.snapshot_dir
368
+ draft.modules = session.snapshot_modules
369
+ return draft.finalize()
370
+ except ImportError:
371
+ log.warning("nshrunner not found, skipping snapshot information")
372
+ return cls.empty()
365
373
 
366
374
 
367
375
  class EnvironmentPackageConfig(C.Config):
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
log = logging.getLogger(__name__)


def get_code_dir() -> Path | None:
    """Return the code-snapshot directory of the current nshrunner session.

    The lookup is best-effort. Whenever no usable directory is available,
    ``None`` is returned and the reason is logged, so callers can simply skip
    their code-upload step.

    Resolution order:
      1. ``session.code_dir`` -- attribute present on newer nshrunner versions.
      2. ``session.snapshot_dir`` -- fallback when ``code_dir`` does not exist.

    Returns:
        The directory as a :class:`~pathlib.Path`, or ``None`` in any of these
        cases:
          * nshrunner cannot be imported;
          * there is no active session;
          * the session records no directory;
          * the recorded path is not an existing directory.
    """
    # Keep the ``try`` minimal. Only the import itself signals "nshrunner is
    # not installed"; an ImportError raised elsewhere should not be
    # mislabelled as that.
    try:
        import nshrunner as nr
    except ImportError:
        log.debug("nshrunner not found. Skipping code upload.")
        return None

    if (session := nr.Session.from_current_session()) is None:
        log.debug("No active session found. Skipping code upload.")
        return None

    # New versions of nshrunner will have the code_dir attribute
    # in the session object. We should use that. Otherwise, use snapshot_dir.
    try:
        code_dir = session.code_dir  # type: ignore
    except AttributeError:
        code_dir = session.snapshot_dir

    if code_dir is None:
        log.debug("No code directory found. Skipping code upload.")
        return None

    # Coerce instead of asserting the type. An ``assert`` is stripped under
    # ``python -O``, and a ``str``/PathLike path is just as usable here.
    code_dir = Path(code_dir)

    # ``Path.is_dir()`` returns False for missing paths, so this single check
    # also covers existence.
    if not code_dir.is_dir():
        log.warning(
            f"Code directory '{code_dir}' does not exist or is not a directory."
        )
        return None

    return code_dir
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.3.3
3
+ Version: 1.3.5
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: GitPython ; extra == "extra"
15
15
  Requires-Dist: huggingface-hub ; extra == "extra"
16
16
  Requires-Dist: lightning
17
17
  Requires-Dist: nshconfig (>0.39)
18
- Requires-Dist: nshrunner
18
+ Requires-Dist: nshrunner ; extra == "extra"
19
19
  Requires-Dist: nshutils ; extra == "extra"
20
20
  Requires-Dist: numpy
21
21
  Requires-Dist: packaging
@@ -5,7 +5,7 @@ nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
6
6
  nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
- nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
8
+ nshtrainer/_hf_hub.py,sha256=kfN0wDxK5JWKKGZnX_706i0KXGhaS19p581LDTPxlRE,13996
9
9
  nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
10
10
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
11
11
  nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
@@ -17,7 +17,7 @@ nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g
17
17
  nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
18
18
  nshtrainer/callbacks/directory_setup.py,sha256=Ln6f0tCgoBscHeigIAWtCCoAmuWB-kPyaf7SylU7MYo,2773
19
19
  nshtrainer/callbacks/distributed_prediction_writer.py,sha256=PvxV9E9lHT-NQ-h1ld7WugajqiFyFXECsreUt3e7pxk,5440
20
- nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
20
+ nshtrainer/callbacks/early_stopping.py,sha256=LTwOME4-_Zld08UjOeeoNxPOg-hCN7o9MAUWVdzGDdk,5467
21
21
  nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
22
22
  nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
23
23
  nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
@@ -30,7 +30,7 @@ nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0z
30
30
  nshtrainer/callbacks/rlp_sanity_checks.py,sha256=Df9Prq2QKXnaeMBIvMQBhDhJTDeru5UbiuXJOJR16Gk,10050
31
31
  nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
32
32
  nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
33
- nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
33
+ nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
34
34
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
35
35
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
36
36
  nshtrainer/configs/__init__.py,sha256=KD3uClMwnA4LfQ7rY5phDdUbp3j8NoZfaGbGPbpaJVs,15848
@@ -146,11 +146,12 @@ nshtrainer/trainer/plugin/environment.py,sha256=SSXRWHjyFUA6oFx3duD_ZwhM59pWUjR1
146
146
  nshtrainer/trainer/plugin/io.py,sha256=OmFSKLloMypletjaUr_Ptg6LS0ljqTVIp2o4Hm3eZoE,1926
147
147
  nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLcOfPXnvH29s,663
148
148
  nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
149
- nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
149
+ nshtrainer/trainer/signal_connector.py,sha256=ZgbSkbthoe8MYN6rBoFf-7UDpQtc9fs9pG_FNvTYSfs,10962
150
150
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
151
151
  nshtrainer/trainer/trainer.py,sha256=6oky6E8cjGqUNzJGyyTO551pE9A6YueOv5oxg1fZVR0,24129
152
- nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
152
+ nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
153
153
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
154
+ nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
154
155
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
155
156
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
156
157
  nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
@@ -160,6 +161,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
160
161
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
161
162
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
162
163
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
163
- nshtrainer-1.3.3.dist-info/METADATA,sha256=K_xd3BrF1Yz7gGbNQgywkjysCFuwXi3GCBoQ5EaFVKY,960
164
- nshtrainer-1.3.3.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
165
- nshtrainer-1.3.3.dist-info/RECORD,,
164
+ nshtrainer-1.3.5.dist-info/METADATA,sha256=GUU8QgA8rxeCX1Z9FfwSvZQ46f0xsMvtm4p1Uz8uEwE,979
165
+ nshtrainer-1.3.5.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
166
+ nshtrainer-1.3.5.dist-info/RECORD,,