nshtrainer 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
nshtrainer/callbacks/__init__.py CHANGED
@@ -39,6 +39,8 @@ from .shared_parameters import SharedParametersConfig as SharedParametersConfig
 from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
 from .timer import EpochTimer as EpochTimer
 from .timer import EpochTimerConfig as EpochTimerConfig
+from .wandb_upload_code import WandbUploadCodeCallback as WandbUploadCodeCallback
+from .wandb_upload_code import WandbUploadCodeConfig as WandbUploadCodeConfig
 from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchConfig as WandbWatchConfig
 
@@ -57,6 +59,7 @@ CallbackConfig = Annotated[
     | OnExceptionCheckpointCallbackConfig
     | SharedParametersConfig
     | RLPSanityChecksConfig
-    | WandbWatchConfig,
+    | WandbWatchConfig
+    | WandbUploadCodeConfig,
     C.Field(discriminator="name"),
 ]
nshtrainer/callbacks/wandb_upload_code.py ADDED
@@ -0,0 +1,79 @@
+import logging
+import os
+from pathlib import Path
+from typing import Literal, cast
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.loggers import WandbLogger
+from nshrunner._env import SNAPSHOT_DIR
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+class WandbUploadCodeConfig(CallbackConfigBase):
+    name: Literal["wandb_upload_code"] = "wandb_upload_code"
+
+    enabled: bool = True
+    """Enable uploading the code to wandb."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield WandbUploadCodeCallback(self)
+
+
+class WandbUploadCodeCallback(Callback):
+    def __init__(self, config: WandbUploadCodeConfig):
+        super().__init__()
+
+        self.config = config
+
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+        if not self.config:
+            return
+
+        if not trainer.is_global_zero:
+            return
+
+        if (
+            logger := next(
+                (
+                    logger
+                    for logger in trainer.loggers
+                    if isinstance(logger, WandbLogger)
+                ),
+                None,
+            )
+        ) is None:
+            log.warning("Wandb logger not found. Skipping code upload.")
+            return
+
+        from wandb.wandb_run import Run
+
+        run = cast(Run, logger.experiment)
+
+        # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
+        # then upload all contents within the snapshot directory to the repository.
+        if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
+            log.debug("No snapshot directory found. Skipping upload.")
+            return
+
+        snapshot_dir = Path(snapshot_dir)
+        if not snapshot_dir.exists() or not snapshot_dir.is_dir():
+            log.warning(
+                f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
+            )
+            return
+
+        log.info(f"Uploading code from snapshot directory '{snapshot_dir}'")
+        run.log_code(str(snapshot_dir.absolute()))
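For orientation, here is a minimal sketch of how the new callback can be wired in. `WandbUploadCodeConfig` and `WandbLoggerConfig` are the names added or extended in this release; the surrounding run setup (root config, trainer construction) is assumed and simplified:

from nshtrainer.callbacks import WandbUploadCodeConfig
from nshtrainer.loggers.wandb import WandbLoggerConfig

# Attach the code-upload callback to the WandB logger config. During setup,
# WandbLoggerConfig.create_callbacks (see the wandb.py diff below) yields a
# WandbUploadCodeCallback, which reads the snapshot path from the SNAPSHOT_DIR
# environment variable on rank zero and uploads it via run.log_code(...).
logger_config = WandbLoggerConfig(
    log_code=WandbUploadCodeConfig(enabled=True),
)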
nshtrainer/config.py CHANGED
@@ -61,14 +61,17 @@ from nshtrainer.callbacks.throughput_monitor import (
     ThroughputMonitorConfig as ThroughputMonitorConfig,
 )
 from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
+from nshtrainer.callbacks.wandb_upload_code import (
+    WandbUploadCodeConfig as WandbUploadCodeConfig,
+)
 from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
+from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
 from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
 from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
 from nshtrainer.loggers.tensorboard import (
     TensorboardLoggerConfig as TensorboardLoggerConfig,
 )
 from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
-from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
 from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
 from nshtrainer.lr_scheduler.linear_warmup_cosine import (
     DurationConfig as DurationConfig,
nshtrainer/loggers/wandb.py CHANGED
@@ -5,9 +5,10 @@ from typing import TYPE_CHECKING, Literal
 import nshconfig as C
 from lightning.pytorch import Callback, LightningModule, Trainer
 from packaging import version
-from typing_extensions import override
+from typing_extensions import assert_never, override
 
 from ..callbacks.base import CallbackConfigBase
+from ..callbacks.wandb_upload_code import WandbUploadCodeConfig
 from ..callbacks.wandb_watch import WandbWatchConfig
 from ._base import BaseLoggerConfig
 
@@ -82,15 +83,18 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     project: str | None = None
     """WandB project name to use for the logger. If None, will use the root config's project name."""
 
-    log_model: bool | Literal["all"] = False
+    log_model: Literal["all", "latest", "none"] | bool = False
     """
     Whether to log the model checkpoints to wandb.
     Valid values are:
-    - False: Do not log the model checkpoints.
-    - True: Log the latest model checkpoint.
-    - "all": Log all model checkpoints.
+    - "all": Log all checkpoints.
+    - "latest" or True: Log only the latest checkpoint.
+    - "none" or False: Do not log any checkpoints
     """
 
+    log_code: WandbUploadCodeConfig | None = None
+    """WandB code upload configuration. Used to upload code to WandB."""
+
     watch: WandbWatchConfig | None = WandbWatchConfig()
     """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
 
@@ -110,6 +114,18 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
         self.use_wandb_core = value
         return self
 
+    @property
+    def _lightning_log_model(self) -> Literal["all"] | bool:
+        match self.log_model:
+            case "all":
+                return "all"
+            case "latest" | True:
+                return True
+            case "none" | False:
+                return False
+            case _:
+                assert_never(self.log_model)
+
     @override
     def create_logger(self, root_config):
         if not self.enabled:
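The new `_lightning_log_model` property normalizes the richer `log_model` values down to the `bool | "all"` form that Lightning's `WandbLogger` accepts. A quick sketch of the mapping (assuming the config can be constructed with only `log_model` set):

# "all"            -> "all"   (log every checkpoint)
# "latest" or True -> True    (log only the latest checkpoint)
# "none" or False  -> False   (log no checkpoints)
assert WandbLoggerConfig(log_model="latest")._lightning_log_model is True
assert WandbLoggerConfig(log_model="none")._lightning_log_model is False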
@@ -128,11 +144,28 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
                         f"(expected version >= 0.17.5, found version {wandb.__version__}). "
                         "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
                     )
-                else:
+                # W&B versions 0.18.0 use wandb-core by default
+                elif wandb_version < version.parse("0.18.0"):
                     wandb.require("core") # type: ignore
                     log.critical("Using the `wandb-core` backend for WandB.")
             except ImportError:
                 pass
+        else:
+            # W&B versions 0.18.0 use wandb-core by default,
+            # so if `use_wandb_core` is False, we should use the old backend
+            # explicitly.
+            wandb_version = version.parse(importlib.metadata.version("wandb"))
+            if wandb_version >= version.parse("0.18.0"):
+                log.warning(
+                    "Explicitly using the old backend for WandB. "
+                    "If you want to use the new `wandb-core` backend, set `use_wandb_core=True`."
+                )
+                try:
+                    import wandb # type: ignore
+
+                    wandb.require("legacy-service") # type: ignore
+                except ImportError:
+                    pass
 
         from lightning.pytorch.loggers.wandb import WandbLogger
 
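Putting the two branches together, backend selection after this change works out as follows (a summary of the code above, not additional behavior):

# use_wandb_core=True,  wandb <  0.17.5          -> rejected with the message above (upgrade or disable the option)
# use_wandb_core=True,  0.17.5 <= wandb < 0.18.0 -> wandb.require("core")
# use_wandb_core=True,  wandb >= 0.18.0          -> nothing to do (wandb-core is the default)
# use_wandb_core=False, wandb >= 0.18.0          -> wandb.require("legacy-service"), with a warning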
@@ -145,7 +178,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
             project=self.project or _project_name(root_config),
             name=root_config.run_name,
             version=root_config.id,
-            log_model=self.log_model,
+            log_model=self._lightning_log_model,
             notes=(
                 "\n".join(f"- {note}" for note in root_config.notes)
                 if root_config.notes
@@ -161,3 +194,6 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
 
         if self.watch:
             yield from self.watch.create_callbacks(root_config)
+
+        if self.log_code:
+            yield from self.log_code.create_callbacks(root_config)
nshtrainer/model/mixins/callback.py CHANGED
@@ -13,11 +13,11 @@ CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
 
 
 class CallbackRegistrarModuleMixin:
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self._nshtrainer_callbacks: list[CallbackFn] = []
+    @property
+    def _nshtrainer_callbacks(self) -> list[CallbackFn]:
+        if not hasattr(self, "_private_nshtrainer_callbacks_list"):
+            self._private_nshtrainer_callbacks_list = []
+        return self._private_nshtrainer_callbacks_list
 
     def register_callback(
         self,
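The mixin no longer initializes its callback list in `__init__` (which required every subclass to run the cooperative `super().__init__()` chain); the list is now created lazily on first access. A standalone sketch of the pattern (hypothetical names, not the library's class):

class LazyRegistry:
    @property
    def _callbacks(self) -> list:
        # Create the backing list on first access, so the registry works
        # even if no cooperative __init__ chain ever ran.
        if not hasattr(self, "_backing_list"):
            self._backing_list = []
        return self._backing_list

r = LazyRegistry()
r._callbacks.append(lambda: None)  # first access creates the list
assert len(r._callbacks) == 1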
{nshtrainer-0.35.0.dist-info → nshtrainer-0.36.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.35.0
+Version: 0.36.0
 Summary: 
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.35.0.dist-info → nshtrainer-0.36.0.dist-info}/RECORD RENAMED
@@ -6,7 +6,7 @@ nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28Ad
 nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=1SBLpMsx7BzgimO35MwQViYBcbgxlkyvTMz1JKUKK-0,3060
+nshtrainer/callbacks/__init__.py,sha256=y4QKDiI4IykcHncDOE-OemFFQqegxyeoD_3v9i2OwFw,3248
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -29,8 +29,9 @@ nshtrainer/callbacks/rlp_sanity_checks.py,sha256=c30G9jAu42QLLIS5LnusdSnI3wqyIHg
 nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5AehYDtuJjDtUum4,2922
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
+nshtrainer/callbacks/wandb_upload_code.py,sha256=OWG4UkL2SfW6oj6AGRXeBJsZmgsqeHLW2Fj8Jm4ga3I,2298
 nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
-nshtrainer/config.py,sha256=6U7B-kCIMrfEnF_y92RuBm1WfASW7k05Zsm2uHBzRrk,8205
+nshtrainer/config.py,sha256=6jApGtO9DVFoXKr9_Z7-MFG5R4WXjbpzZ6jkNI3yD-Y,8306
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
@@ -55,7 +56,7 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
 nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
 nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
 nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
-nshtrainer/loggers/wandb.py,sha256=C-yGX9e2FUSfbUxur7-meNUjpB3D8hIdVCOgPzGm3QM,5140
+nshtrainer/loggers/wandb.py,sha256=td8J2v8T1nvGQI7OYQ1El6k8FGsXZxbnuY97s8KzCiY,6643
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=E8LW78uuby7bIsoLPpcF1bmNK4lSko-r3qPL-vuHWXQ,5370
@@ -65,7 +66,7 @@ nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ
 nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
 nshtrainer/model/base.py,sha256=NasbYZJBuEly6Hm9t9HVZk-CUHmy4T7p1v-Ye981XA4,18609
 nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
-nshtrainer/model/mixins/callback.py,sha256=lvX9Q2ErETXmGFd79CscSAOJAlTWq-mwMKVC0d0uH1c,2324
+nshtrainer/model/mixins/callback.py,sha256=rbe8P22iEjPkH1df6rfEo3Txw7EwSz6Dkm0TWO_AysM,2419
 nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
@@ -98,6 +99,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.35.0.dist-info/METADATA,sha256=NBZegh-RUfnkVt_ERUPdH7fdCFZriQZXoMskq_8HB60,916
-nshtrainer-0.35.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.35.0.dist-info/RECORD,,
+nshtrainer-0.36.0.dist-info/METADATA,sha256=eWSmYvpViGa536HEJ1zX2IUoRpt-wlgiNmfv8mPm7Yg,916
+nshtrainer-0.36.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.36.0.dist-info/RECORD,,