nshtrainer 0.10.13__py3-none-any.whl → 0.10.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/wandb_watch.py +24 -24
- nshtrainer/model/config.py +37 -6
- {nshtrainer-0.10.13.dist-info → nshtrainer-0.10.15.dist-info}/METADATA +1 -1
- {nshtrainer-0.10.13.dist-info → nshtrainer-0.10.15.dist-info}/RECORD +5 -5
- {nshtrainer-0.10.13.dist-info → nshtrainer-0.10.15.dist-info}/WHEEL +0 -0
|
@@ -12,13 +12,36 @@ from .base import CallbackConfigBase
|
|
|
12
12
|
log = logging.getLogger(__name__)
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
class WandbWatchConfig(CallbackConfigBase):
|
|
16
|
+
name: Literal["wandb_watch"] = "wandb_watch"
|
|
17
|
+
|
|
18
|
+
enabled: bool = True
|
|
19
|
+
"""Enable watching the model for wandb."""
|
|
20
|
+
|
|
21
|
+
log: str | None = None
|
|
22
|
+
"""Log type for wandb."""
|
|
23
|
+
|
|
24
|
+
log_graph: bool = True
|
|
25
|
+
"""Whether to log the graph for wandb."""
|
|
26
|
+
|
|
27
|
+
log_freq: int = 100
|
|
28
|
+
"""Log frequency for wandb."""
|
|
29
|
+
|
|
30
|
+
def __bool__(self):
|
|
31
|
+
return self.enabled
|
|
32
|
+
|
|
33
|
+
@override
|
|
34
|
+
def create_callbacks(self, root_config):
|
|
35
|
+
yield WandbWatchCallback(self)
|
|
36
|
+
|
|
37
|
+
|
|
15
38
|
@runtime_checkable
|
|
16
39
|
class _HasWandbLogModuleProtocol(Protocol):
|
|
17
40
|
def wandb_log_module(self) -> nn.Module | None: ...
|
|
18
41
|
|
|
19
42
|
|
|
20
43
|
class WandbWatchCallback(Callback):
|
|
21
|
-
def __init__(self, config:
|
|
44
|
+
def __init__(self, config: WandbWatchConfig):
|
|
22
45
|
super().__init__()
|
|
23
46
|
|
|
24
47
|
self.config = config
|
|
@@ -78,26 +101,3 @@ class WandbWatchCallback(Callback):
|
|
|
78
101
|
log_graph=self.config.log_graph,
|
|
79
102
|
)
|
|
80
103
|
setattr(pl_module, "_model_watched", True)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
class WandbWatchConfig(CallbackConfigBase):
|
|
84
|
-
name: Literal["wandb_watch"] = "wandb_watch"
|
|
85
|
-
|
|
86
|
-
enabled: bool = True
|
|
87
|
-
"""Enable watching the model for wandb."""
|
|
88
|
-
|
|
89
|
-
log: str | None = None
|
|
90
|
-
"""Log type for wandb."""
|
|
91
|
-
|
|
92
|
-
log_graph: bool = True
|
|
93
|
-
"""Whether to log the graph for wandb."""
|
|
94
|
-
|
|
95
|
-
log_freq: int = 100
|
|
96
|
-
"""Log frequency for wandb."""
|
|
97
|
-
|
|
98
|
-
def __bool__(self):
|
|
99
|
-
return self.enabled
|
|
100
|
-
|
|
101
|
-
@override
|
|
102
|
-
def create_callbacks(self, root_config):
|
|
103
|
-
yield WandbWatchCallback(self)
|
nshtrainer/model/config.py
CHANGED
|
@@ -20,6 +20,7 @@ from typing import (
|
|
|
20
20
|
|
|
21
21
|
import nshconfig as C
|
|
22
22
|
import numpy as np
|
|
23
|
+
import pkg_resources
|
|
23
24
|
import torch
|
|
24
25
|
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
|
|
25
26
|
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
|
|
@@ -213,7 +214,7 @@ class BaseLoggerConfig(C.Config, ABC):
|
|
|
213
214
|
"""Enable this logger."""
|
|
214
215
|
|
|
215
216
|
priority: int = 0
|
|
216
|
-
"""Priority of the logger. Higher
|
|
217
|
+
"""Priority of the logger. Higher priority loggers are created first."""
|
|
217
218
|
|
|
218
219
|
log_dir: DirectoryPath | None = None
|
|
219
220
|
"""Directory to save the logs to. If None, will use the default log directory for the trainer."""
|
|
@@ -266,7 +267,8 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
|
|
|
266
267
|
"""Enable WandB logging."""
|
|
267
268
|
|
|
268
269
|
priority: int = 2
|
|
269
|
-
"""Priority of the logger. Higher
|
|
270
|
+
"""Priority of the logger. Higher priority loggers are created first,
|
|
271
|
+
and the highest priority logger is the "main" logger for PyTorch Lightning."""
|
|
270
272
|
|
|
271
273
|
project: str | None = None
|
|
272
274
|
"""WandB project name to use for the logger. If None, will use the root config's project name."""
|
|
@@ -286,8 +288,17 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
|
|
|
286
288
|
offline: bool = False
|
|
287
289
|
"""Whether to run WandB in offline mode."""
|
|
288
290
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
+
use_wandb_core: bool = False
|
|
292
|
+
"""Whether to use the new `wandb-core` backend for WandB.
|
|
293
|
+
`wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
|
|
294
|
+
"""
|
|
295
|
+
|
|
296
|
+
def offline_(self, value: bool = True):
|
|
297
|
+
self.offline = value
|
|
298
|
+
return self
|
|
299
|
+
|
|
300
|
+
def core_(self, value: bool = True):
|
|
301
|
+
self.use_wandb_core = value
|
|
291
302
|
return self
|
|
292
303
|
|
|
293
304
|
@override
|
|
@@ -295,6 +306,26 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
|
|
|
295
306
|
if not self.enabled:
|
|
296
307
|
return None
|
|
297
308
|
|
|
309
|
+
# If `wandb-core` is enabled, we should use the new backend.
|
|
310
|
+
if self.use_wandb_core:
|
|
311
|
+
try:
|
|
312
|
+
import wandb # type: ignore
|
|
313
|
+
|
|
314
|
+
# The minimum version that supports the new backend is 0.17.5
|
|
315
|
+
if pkg_resources.parse_version(
|
|
316
|
+
wandb.__version__
|
|
317
|
+
) < pkg_resources.parse_version("0.17.5"):
|
|
318
|
+
raise ValueError(
|
|
319
|
+
"The version of WandB installed does not support the `wandb-core` backend "
|
|
320
|
+
f"(expected version >= 0.17.5, found version {wandb.__version__}). "
|
|
321
|
+
"Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
|
|
322
|
+
)
|
|
323
|
+
else:
|
|
324
|
+
wandb.require("core")
|
|
325
|
+
log.critical("Using the `wandb-core` backend for WandB.")
|
|
326
|
+
except ImportError:
|
|
327
|
+
pass
|
|
328
|
+
|
|
298
329
|
from lightning.pytorch.loggers.wandb import WandbLogger
|
|
299
330
|
|
|
300
331
|
save_dir = root_config.directory._resolve_log_directory_for_logger(
|
|
@@ -329,7 +360,7 @@ class CSVLoggerConfig(BaseLoggerConfig):
|
|
|
329
360
|
"""Enable CSV logging."""
|
|
330
361
|
|
|
331
362
|
priority: int = 0
|
|
332
|
-
"""Priority of the logger. Higher
|
|
363
|
+
"""Priority of the logger. Higher priority loggers are created first."""
|
|
333
364
|
|
|
334
365
|
prefix: str = ""
|
|
335
366
|
"""A string to put at the beginning of metric keys."""
|
|
@@ -383,7 +414,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
|
|
|
383
414
|
"""Enable TensorBoard logging."""
|
|
384
415
|
|
|
385
416
|
priority: int = 2
|
|
386
|
-
"""Priority of the logger. Higher
|
|
417
|
+
"""Priority of the logger. Higher priority loggers are created first."""
|
|
387
418
|
|
|
388
419
|
log_graph: bool = False
|
|
389
420
|
"""
|
|
@@ -23,7 +23,7 @@ nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKh
|
|
|
23
23
|
nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
|
|
24
24
|
nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
|
|
25
25
|
nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
|
|
26
|
-
nshtrainer/callbacks/wandb_watch.py,sha256=
|
|
26
|
+
nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
|
|
27
27
|
nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
|
|
28
28
|
nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
|
|
29
29
|
nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
|
|
@@ -52,7 +52,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
|
|
|
52
52
|
nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
|
|
53
53
|
nshtrainer/model/__init__.py,sha256=NpvyQHmGaHB8xdraHmm8l7kDHLmvJSgBNQKkfYqtgyI,1454
|
|
54
54
|
nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
|
|
55
|
-
nshtrainer/model/config.py,sha256=
|
|
55
|
+
nshtrainer/model/config.py,sha256=npR8undYPqjIGlAZpm4suRP77qE9R42G_9Y-2Am9Wh4,54780
|
|
56
56
|
nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
|
|
57
57
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
58
58
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
@@ -79,6 +79,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
79
79
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
80
80
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
81
81
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
82
|
-
nshtrainer-0.10.
|
|
83
|
-
nshtrainer-0.10.
|
|
84
|
-
nshtrainer-0.10.
|
|
82
|
+
nshtrainer-0.10.15.dist-info/METADATA,sha256=lBdMigvT3LEgOyWtMBwaRvru8XRTU8K5GQ-ll3kqwE8,696
|
|
83
|
+
nshtrainer-0.10.15.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
84
|
+
nshtrainer-0.10.15.dist-info/RECORD,,
|
|
File without changes
|