nshtrainer 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/early_stopping.py +1 -3
- nshtrainer/ll/__init__.py +0 -2
- nshtrainer/loggers/__init__.py +13 -0
- nshtrainer/loggers/_base.py +26 -0
- nshtrainer/loggers/csv.py +40 -0
- nshtrainer/loggers/tensorboard.py +73 -0
- nshtrainer/loggers/wandb.py +163 -0
- nshtrainer/model/__init__.py +0 -2
- nshtrainer/model/config.py +14 -259
- nshtrainer/trainer/trainer.py +1 -1
- {nshtrainer-0.13.1.dist-info → nshtrainer-0.14.1.dist-info}/METADATA +1 -1
- {nshtrainer-0.13.1.dist-info → nshtrainer-0.14.1.dist-info}/RECORD +13 -8
- {nshtrainer-0.13.1.dist-info → nshtrainer-0.14.1.dist-info}/WHEEL +0 -0
nshtrainer/callbacks/early_stopping.py CHANGED
@@ -73,9 +73,7 @@ class EarlyStopping(_EarlyStopping):
 
     @override
     @staticmethod
-    def _log_info(
-        trainer: Trainer | None, message: str, log_rank_zero_only: bool
-    ) -> None:
+    def _log_info(trainer: Trainer | None, message: str, log_rank_zero_only: bool):
        rank = _get_rank()
        if trainer is not None and trainer.world_size <= 1:
            rank = None
nshtrainer/ll/__init__.py CHANGED
@@ -23,7 +23,6 @@ from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
 from .model import Base as Base
 from .model import BaseConfig as BaseConfig
-from .model import BaseLoggerConfig as BaseLoggerConfig
 from .model import BaseProfilerConfig as BaseProfilerConfig
 from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .model import CheckpointSavingConfig as CheckpointSavingConfig
@@ -48,7 +47,6 @@ from .model import PrimaryMetricConfig as PrimaryMetricConfig
 from .model import ReproducibilityConfig as ReproducibilityConfig
 from .model import SanityCheckingConfig as SanityCheckingConfig
 from .model import TrainerConfig as TrainerConfig
-from .model import WandbWatchConfig as WandbWatchConfig
 from .nn import TypedModuleDict as TypedModuleDict
 from .nn import TypedModuleList as TypedModuleList
 from .optimizer import OptimizerConfig as OptimizerConfig
nshtrainer/loggers/__init__.py ADDED
@@ -0,0 +1,13 @@
+from typing import Annotated, TypeAlias
+
+import nshconfig as C
+
+from ._base import BaseLoggerConfig as BaseLoggerConfig
+from .csv import CSVLoggerConfig as CSVLoggerConfig
+from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
+from .wandb import WandbLoggerConfig as WandbLoggerConfig
+
+LoggerConfig: TypeAlias = Annotated[
+    CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig,
+    C.Field(discriminator="name"),
+]
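The new LoggerConfig alias is a discriminated union keyed on each config's `name` literal. A minimal illustration of what that enables (the configs are constructed directly here; the exact nshconfig validation entry point is not shown in this diff, and the project name is made up):

from nshtrainer.loggers import CSVLoggerConfig, WandbLoggerConfig

loggers = [
    CSVLoggerConfig(),                        # serializes with name="csv"
    WandbLoggerConfig(project="my-project"),  # serializes with name="wandb"
]
# C.Field(discriminator="name") uses that literal to pick the matching config
# class when a serialized logger entry is loaded back in.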
nshtrainer/loggers/_base.py ADDED
@@ -0,0 +1,26 @@
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+import nshconfig as C
+from lightning.pytorch.loggers import Logger
+
+if TYPE_CHECKING:
+    from ..model import BaseConfig
+
+
+class BaseLoggerConfig(C.Config, ABC):
+    enabled: bool = True
+    """Enable this logger."""
+
+    priority: int = 0
+    """Priority of the logger. Higher priority loggers are created first."""
+
+    log_dir: C.DirectoryPath | None = None
+    """Directory to save the logs to. If None, will use the default log directory for the trainer."""
+
+    @abstractmethod
+    def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...
+
+    def disable_(self):
+        self.enabled = False
+        return self
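This abstract base is the extension point for additional loggers. A hypothetical sketch of a custom subclass (the `my-custom` name and MyCustomLoggerConfig class are invented for illustration and are not part of the package; a real implementation would return its own Lightning Logger):

from typing import Literal

from typing_extensions import override

from nshtrainer.loggers import BaseLoggerConfig


class MyCustomLoggerConfig(BaseLoggerConfig):
    # Hypothetical example only; not shipped with nshtrainer.
    name: Literal["my-custom"] = "my-custom"

    @override
    def create_logger(self, root_config):
        if not self.enabled:
            return None
        # CSVLogger is used here just to return something concrete; a real
        # subclass would construct its own lightning.pytorch.loggers.Logger.
        from lightning.pytorch.loggers.csv_logs import CSVLogger

        return CSVLogger(save_dir=self.log_dir or ".", name=root_config.run_name)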
nshtrainer/loggers/csv.py ADDED
@@ -0,0 +1,40 @@
+from typing import Literal
+
+from typing_extensions import override
+
+from ._base import BaseLoggerConfig
+
+
+class CSVLoggerConfig(BaseLoggerConfig):
+    name: Literal["csv"] = "csv"
+
+    enabled: bool = True
+    """Enable CSV logging."""
+
+    priority: int = 0
+    """Priority of the logger. Higher priority loggers are created first."""
+
+    prefix: str = ""
+    """A string to put at the beginning of metric keys."""
+
+    flush_logs_every_n_steps: int = 100
+    """How often to flush logs to disk."""
+
+    @override
+    def create_logger(self, root_config):
+        if not self.enabled:
+            return None
+
+        from lightning.pytorch.loggers.csv_logs import CSVLogger
+
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
+            root_config.id,
+            self,
+        )
+        return CSVLogger(
+            save_dir=save_dir,
+            name=root_config.run_name,
+            version=root_config.id,
+            prefix=self.prefix,
+            flush_logs_every_n_steps=self.flush_logs_every_n_steps,
+        )
nshtrainer/loggers/tensorboard.py ADDED
@@ -0,0 +1,73 @@
+import logging
+from typing import Literal
+
+import nshconfig as C
+from typing_extensions import override
+
+from ._base import BaseLoggerConfig
+
+log = logging.getLogger(__name__)
+
+
+def _tensorboard_available():
+    try:
+        from lightning.fabric.loggers.tensorboard import (
+            _TENSORBOARD_AVAILABLE,
+            _TENSORBOARDX_AVAILABLE,
+        )
+
+        if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
+            log.warning(
+                "TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. "
+                "Please install TensorBoard with `pip install tensorboard` or "
+                "TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging."
+            )
+            return False
+        return True
+    except ImportError:
+        return False
+
+
+class TensorboardLoggerConfig(BaseLoggerConfig):
+    name: Literal["tensorboard"] = "tensorboard"
+
+    enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
+    """Enable TensorBoard logging."""
+
+    priority: int = 2
+    """Priority of the logger. Higher priority loggers are created first."""
+
+    log_graph: bool = False
+    """
+    Adds the computational graph to tensorboard. This requires that
+    the user has defined the `self.example_input_array` attribute in their
+    model.
+    """
+
+    default_hp_metric: bool = True
+    """
+    Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
+    called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
+    """
+
+    prefix: str = ""
+    """A string to put at the beginning of metric keys."""
+
+    @override
+    def create_logger(self, root_config):
+        if not self.enabled:
+            return None
+
+        from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
+
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
+            root_config.id,
+            self,
+        )
+        return TensorBoardLogger(
+            save_dir=save_dir,
+            name=root_config.run_name,
+            version=root_config.id,
+            log_graph=self.log_graph,
+            default_hp_metric=self.default_hp_metric,
+        )
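As the `log_graph` docstring above notes, the computational graph is only logged when the LightningModule defines `example_input_array`. A small hedged sketch (module and tensor shapes are made up for illustration):

import torch
from lightning.pytorch import LightningModule


class ExampleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(16, 1)
        # Required for TensorboardLoggerConfig(log_graph=True) to trace the graph.
        self.example_input_array = torch.zeros(1, 16)

    def forward(self, x):
        return self.net(x)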
nshtrainer/loggers/wandb.py ADDED
@@ -0,0 +1,163 @@
+import logging
+from typing import TYPE_CHECKING, Literal
+
+import nshconfig as C
+import pkg_resources
+from lightning.pytorch import Callback, LightningModule, Trainer
+from typing_extensions import override
+
+from ..callbacks import WandbWatchConfig
+from ..callbacks.base import CallbackConfigBase
+from ._base import BaseLoggerConfig
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+def _project_name(
+    root_config: "BaseConfig",
+    default_project: str = "lightning_logs",
+):
+    # If the config has a project name, use that.
+    if project := root_config.project:
+        return project
+
+    # Otherwise, we should use the name of the module that the config is defined in,
+    # if we can find it.
+    # If this isn't in a module, use the default project name.
+    if not (module := root_config.__module__):
+        return default_project
+
+    # If the module is a package, use the package name.
+    if not (module := module.split(".", maxsplit=1)[0].strip()):
+        return default_project
+
+    return module
+
+
+def _wandb_available():
+    try:
+        from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
+
+        if not _WANDB_AVAILABLE:
+            log.warning("WandB not found. Disabling WandbLogger.")
+            return False
+        return True
+    except ImportError:
+        return False
+
+
+class FinishWandbOnTeardownCallback(Callback):
+    @override
+    def teardown(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        stage: str,
+    ):
+        try:
+            import wandb  # type: ignore
+        except ImportError:
+            return
+
+        if wandb.run is None:
+            return
+
+        wandb.finish()
+
+
+class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
+    name: Literal["wandb"] = "wandb"
+
+    enabled: bool = C.Field(default_factory=lambda: _wandb_available())
+    """Enable WandB logging."""
+
+    priority: int = 2
+    """Priority of the logger. Higher priority loggers are created first,
+    and the highest priority logger is the "main" logger for PyTorch Lightning."""
+
+    project: str | None = None
+    """WandB project name to use for the logger. If None, will use the root config's project name."""
+
+    log_model: bool | Literal["all"] = False
+    """
+    Whether to log the model checkpoints to wandb.
+    Valid values are:
+    - False: Do not log the model checkpoints.
+    - True: Log the latest model checkpoint.
+    - "all": Log all model checkpoints.
+    """
+
+    watch: WandbWatchConfig | None = WandbWatchConfig()
+    """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
+
+    offline: bool = False
+    """Whether to run WandB in offline mode."""
+
+    use_wandb_core: bool = True
+    """Whether to use the new `wandb-core` backend for WandB.
+    `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
+    """
+
+    def offline_(self, value: bool = True):
+        self.offline = value
+        return self
+
+    def core_(self, value: bool = True):
+        self.use_wandb_core = value
+        return self
+
+    @override
+    def create_logger(self, root_config):
+        if not self.enabled:
+            return None
+
+        # If `wandb-core` is enabled, we should use the new backend.
+        if self.use_wandb_core:
+            try:
+                import wandb  # type: ignore
+
+                # The minimum version that supports the new backend is 0.17.5
+                if pkg_resources.parse_version(
+                    wandb.__version__
+                ) < pkg_resources.parse_version("0.17.5"):
+                    log.warning(
+                        "The version of WandB installed does not support the `wandb-core` backend "
+                        f"(expected version >= 0.17.5, found version {wandb.__version__}). "
+                        "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
+                    )
+                else:
+                    wandb.require("core")
+                    log.critical("Using the `wandb-core` backend for WandB.")
+            except ImportError:
+                pass
+
+        from lightning.pytorch.loggers.wandb import WandbLogger
+
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
+            root_config.id,
+            self,
+        )
+        return WandbLogger(
+            save_dir=save_dir,
+            project=self.project or _project_name(root_config),
+            name=root_config.run_name,
+            version=root_config.id,
+            log_model=self.log_model,
+            notes=(
+                "\n".join(f"- {note}" for note in root_config.notes)
+                if root_config.notes
+                else None
+            ),
+            tags=root_config.tags,
+            offline=self.offline,
+        )
+
+    @override
+    def create_callbacks(self, root_config):
+        yield FinishWandbOnTeardownCallback()
+
+        if self.watch:
+            yield from self.watch.create_callbacks(root_config)
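The `offline_()` and `core_()` helpers both return `self`, so they can be chained when building a config. A small illustration (the project name is made up):

cfg = WandbLoggerConfig(project="my-project").offline_().core_(False)
# With use_wandb_core disabled, create_logger() skips the wandb.require("core")
# call and the >= 0.17.5 version check shown above.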
nshtrainer/model/__init__.py CHANGED
@@ -3,7 +3,6 @@ from typing_extensions import TypeAlias
 from .base import Base as Base
 from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
-from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig
 from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
@@ -22,6 +21,5 @@ from .config import PrimaryMetricConfig as PrimaryMetricConfig
 from .config import ReproducibilityConfig as ReproducibilityConfig
 from .config import SanityCheckingConfig as SanityCheckingConfig
 from .config import TrainerConfig as TrainerConfig
-from .config import WandbWatchConfig as WandbWatchConfig
 
 ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
nshtrainer/model/config.py CHANGED
@@ -19,7 +19,6 @@ from typing import (
 
 import nshconfig as C
 import numpy as np
-import pkg_resources
 import torch
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
@@ -31,7 +30,6 @@ from lightning.pytorch.plugins.layer_sync import LayerSync
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
-from pydantic import DirectoryPath
 from typing_extensions import Self, TypedDict, TypeVar, override
 
 from .._checkpoint.loader import CheckpointLoadingConfig
@@ -41,9 +39,14 @@ from ..callbacks import (
     EarlyStoppingConfig,
     LastCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
-    WandbWatchConfig,
 )
 from ..callbacks.base import CallbackConfigBase
+from ..loggers import (
+    CSVLoggerConfig,
+    LoggerConfig,
+    TensorboardLoggerConfig,
+    WandbLoggerConfig,
+)
 from ..metrics import MetricConfig
 from ..util._environment_info import EnvironmentConfig
 
@@ -205,255 +208,6 @@ ProfilerConfig: TypeAlias = Annotated[
 ]
 
 
-class BaseLoggerConfig(C.Config, ABC):
-    enabled: bool = True
-    """Enable this logger."""
-
-    priority: int = 0
-    """Priority of the logger. Higher priority loggers are created first."""
-
-    log_dir: DirectoryPath | None = None
-    """Directory to save the logs to. If None, will use the default log directory for the trainer."""
-
-    @abstractmethod
-    def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...
-
-    def disable_(self):
-        self.enabled = False
-        return self
-
-
-def _project_name(
-    root_config: "BaseConfig",
-    default_project: str = "lightning_logs",
-):
-    # If the config has a project name, use that.
-    if project := root_config.project:
-        return project
-
-    # Otherwise, we should use the name of the module that the config is defined in,
-    # if we can find it.
-    # If this isn't in a module, use the default project name.
-    if not (module := root_config.__module__):
-        return default_project
-
-    # If the module is a package, use the package name.
-    if not (module := module.split(".", maxsplit=1)[0].strip()):
-        return default_project
-
-    return module
-
-
-def _wandb_available():
-    try:
-        from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
-
-        if not _WANDB_AVAILABLE:
-            log.warning("WandB not found. Disabling WandbLogger.")
-            return False
-        return True
-    except ImportError:
-        return False
-
-
-class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
-    name: Literal["wandb"] = "wandb"
-
-    enabled: bool = C.Field(default_factory=lambda: _wandb_available())
-    """Enable WandB logging."""
-
-    priority: int = 2
-    """Priority of the logger. Higher priority loggers are created first,
-    and the highest priority logger is the "main" logger for PyTorch Lightning."""
-
-    project: str | None = None
-    """WandB project name to use for the logger. If None, will use the root config's project name."""
-
-    log_model: bool | Literal["all"] = False
-    """
-    Whether to log the model checkpoints to wandb.
-    Valid values are:
-    - False: Do not log the model checkpoints.
-    - True: Log the latest model checkpoint.
-    - "all": Log all model checkpoints.
-    """
-
-    watch: WandbWatchConfig = WandbWatchConfig()
-    """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
-
-    offline: bool = False
-    """Whether to run WandB in offline mode."""
-
-    use_wandb_core: bool = True
-    """Whether to use the new `wandb-core` backend for WandB.
-    `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
-    """
-
-    def offline_(self, value: bool = True):
-        self.offline = value
-        return self
-
-    def core_(self, value: bool = True):
-        self.use_wandb_core = value
-        return self
-
-    @override
-    def create_logger(self, root_config):
-        if not self.enabled:
-            return None
-
-        # If `wandb-core` is enabled, we should use the new backend.
-        if self.use_wandb_core:
-            try:
-                import wandb  # type: ignore
-
-                # The minimum version that supports the new backend is 0.17.5
-                if pkg_resources.parse_version(
-                    wandb.__version__
-                ) < pkg_resources.parse_version("0.17.5"):
-                    log.warning(
-                        "The version of WandB installed does not support the `wandb-core` backend "
-                        f"(expected version >= 0.17.5, found version {wandb.__version__}). "
-                        "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
-                    )
-                else:
-                    wandb.require("core")
-                    log.critical("Using the `wandb-core` backend for WandB.")
-            except ImportError:
-                pass
-
-        from lightning.pytorch.loggers.wandb import WandbLogger
-
-        save_dir = root_config.directory._resolve_log_directory_for_logger(
-            root_config.id,
-            self,
-        )
-        return WandbLogger(
-            save_dir=save_dir,
-            project=self.project or _project_name(root_config),
-            name=root_config.run_name,
-            version=root_config.id,
-            log_model=self.log_model,
-            notes=(
-                "\n".join(f"- {note}" for note in root_config.notes)
-                if root_config.notes
-                else None
-            ),
-            tags=root_config.tags,
-            offline=self.offline,
-        )
-
-    @override
-    def create_callbacks(self, root_config):
-        if self.watch:
-            yield from self.watch.create_callbacks(root_config)
-
-
-class CSVLoggerConfig(BaseLoggerConfig):
-    name: Literal["csv"] = "csv"
-
-    enabled: bool = True
-    """Enable CSV logging."""
-
-    priority: int = 0
-    """Priority of the logger. Higher priority loggers are created first."""
-
-    prefix: str = ""
-    """A string to put at the beginning of metric keys."""
-
-    flush_logs_every_n_steps: int = 100
-    """How often to flush logs to disk."""
-
-    @override
-    def create_logger(self, root_config):
-        if not self.enabled:
-            return None
-
-        from lightning.pytorch.loggers.csv_logs import CSVLogger
-
-        save_dir = root_config.directory._resolve_log_directory_for_logger(
-            root_config.id,
-            self,
-        )
-        return CSVLogger(
-            save_dir=save_dir,
-            name=root_config.run_name,
-            version=root_config.id,
-            prefix=self.prefix,
-            flush_logs_every_n_steps=self.flush_logs_every_n_steps,
-        )
-
-
-def _tensorboard_available():
-    try:
-        from lightning.fabric.loggers.tensorboard import (
-            _TENSORBOARD_AVAILABLE,
-            _TENSORBOARDX_AVAILABLE,
-        )
-
-        if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
-            log.warning(
-                "TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. "
-                "Please install TensorBoard with `pip install tensorboard` or "
-                "TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging."
-            )
-            return False
-        return True
-    except ImportError:
-        return False
-
-
-class TensorboardLoggerConfig(BaseLoggerConfig):
-    name: Literal["tensorboard"] = "tensorboard"
-
-    enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
-    """Enable TensorBoard logging."""
-
-    priority: int = 2
-    """Priority of the logger. Higher priority loggers are created first."""
-
-    log_graph: bool = False
-    """
-    Adds the computational graph to tensorboard. This requires that
-    the user has defined the `self.example_input_array` attribute in their
-    model.
-    """
-
-    default_hp_metric: bool = True
-    """
-    Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
-    called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
-    """
-
-    prefix: str = ""
-    """A string to put at the beginning of metric keys."""
-
-    @override
-    def create_logger(self, root_config):
-        if not self.enabled:
-            return None
-
-        from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
-
-        save_dir = root_config.directory._resolve_log_directory_for_logger(
-            root_config.id,
-            self,
-        )
-        return TensorBoardLogger(
-            save_dir=save_dir,
-            name=root_config.run_name,
-            version=root_config.id,
-            log_graph=self.log_graph,
-            default_hp_metric=self.default_hp_metric,
-        )
-
-
-LoggerConfig: TypeAlias = Annotated[
-    WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
-    C.Field(discriminator="name"),
-]
-
-
 class LoggingConfig(CallbackConfigBase):
     enabled: bool = True
     """Enable experiment tracking."""
@@ -474,29 +228,32 @@ class LoggingConfig(CallbackConfigBase):
     """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
     @property
-    def wandb(self)
+    def wandb(self):
         return next(
             (
                 logger
                 for logger in self.loggers
                 if isinstance(logger, WandbLoggerConfig)
             ),
+            None,
         )
 
     @property
-    def csv(self)
+    def csv(self):
         return next(
             (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
         )
 
     @property
-    def tensorboard(self)
+    def tensorboard(self):
         return next(
             (
                 logger
                 for logger in self.loggers
                 if isinstance(logger, TensorboardLoggerConfig)
            ),
+            None,
         )
 
     def create_loggers(self, root_config: "BaseConfig"):
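With the added `None` default, these properties no longer raise StopIteration when no logger of that type is configured; callers can simply guard on the result. An illustrative use (assuming `config` is a BaseConfig instance, as in the Trainer code further down):

if (wandb_config := config.trainer.logging.wandb) is not None:
    wandb_config.disable_()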
@@ -509,9 +266,8 @@ class LoggingConfig(CallbackConfigBase):
         Returns:
             list[Logger]: A list of constructed loggers.
         """
-        loggers: list[Logger] = []
         if not self.enabled:
-            return
+            return
 
         for logger_config in sorted(
             self.loggers,
@@ -522,8 +278,7 @@ class LoggingConfig(CallbackConfigBase):
                 continue
             if (logger := logger_config.create_logger(root_config)) is None:
                 continue
-
-            return loggers
+            yield logger
 
     @override
     def create_callbacks(self, root_config):
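`create_loggers` is now a generator: it yields each constructed logger instead of accumulating a list. A hedged sketch of the resulting behaviour (not the package's exact code; the sort key is inferred from the "higher priority loggers are created first" docstrings):

def create_loggers_sketch(logging_config, root_config):
    if not logging_config.enabled:
        return
    # Visit configs from highest to lowest priority, skip disabled ones,
    # and drop any config whose create_logger() returns None.
    for logger_config in sorted(
        logging_config.loggers, key=lambda c: c.priority, reverse=True
    ):
        if not logger_config.enabled:
            continue
        if (logger := logger_config.create_logger(root_config)) is None:
            continue
        yield logger

Consumers that need a concrete sequence, like the Trainer change below, wrap the generator in list(...).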
nshtrainer/trainer/trainer.py CHANGED
@@ -244,7 +244,7 @@ class Trainer(LightningTrainer):
             log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
             kwargs["logger"] = False
         else:
-            _update_kwargs(logger=config.trainer.logging.create_loggers(config))
+            _update_kwargs(logger=list(config.trainer.logging.create_loggers(config)))
 
         if config.trainer.auto_determine_num_nodes:
             # When num_nodes is auto, we need to detect the number of nodes.
{nshtrainer-0.13.1.dist-info → nshtrainer-0.14.1.dist-info}/RECORD CHANGED
@@ -15,7 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=yBUZ63P4a3m5jviIUGxBd2WXq_TigpKt
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
-nshtrainer/callbacks/early_stopping.py,sha256=
+nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
 nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
@@ -29,7 +29,7 @@ nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsi
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
-nshtrainer/ll/__init__.py,sha256=
+nshtrainer/ll/__init__.py,sha256=6UTt2apSD8tOZw3M7hyd-33v4RKSpNNATlWFbW4cNnU,2523
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
 nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
@@ -46,15 +46,20 @@ nshtrainer/ll/snoop.py,sha256=hG9VCdm8mIZytHLZgUKRSoWqg55rVvUpBAH-OSiwgvI,36
 nshtrainer/ll/trainer.py,sha256=hkn2xPtrSPQ7LqQhbyAKuMfNyHdhqB9bDPvgRCK8oJM,47
 nshtrainer/ll/typecheck.py,sha256=ryV1Tzcf7hJ4I19H1oQVkikU9spmRk8jyIKQZ5UF7pQ,62
 nshtrainer/ll/util.py,sha256=PQGu5Ff1raizxjXdm2rFFu4Mo816dBIYmItjkJX1_qk,44
+nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfEk,448
+nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
+nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
+nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
+nshtrainer/loggers/wandb.py,sha256=QteSPfaBWxQ-FuuaF6_rKxw8twJdg_6oHoTQlbqvrXk,5079
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=
+nshtrainer/model/__init__.py,sha256=BmqSbf6v6oyeilti4iEn_Tyrr1kRmcFcJekTb8NeglI,1315
 nshtrainer/model/base.py,sha256=AXRfEsFAT0Ln7zjYVPU5NgtHS_c8FZM-M4pyLamO7OA,17516
-nshtrainer/model/config.py,sha256=
+nshtrainer/model/config.py,sha256=D6Y-Y7GoMrpo7A2dmIqJsqc4X2IHwyl9OEHxO4uOc0g,42918
 nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -74,7 +79,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
 nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=M97phnALfG18VxkMLoDr5AKFf4UaPBdc6S2BghdBtas,17103
 nshtrainer/util/_environment_info.py,sha256=yPtAbgjCY4tkvh5wp9sjNsF0Z45TYwzEAM_N2_b5BbY,23123
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
@@ -82,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.14.1.dist-info/METADATA,sha256=V7kssNBEV-JqILCZbRujHMs4ORomy_hWU4vRK0Je2GU,860
+nshtrainer-0.14.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.14.1.dist-info/RECORD,,

{nshtrainer-0.13.1.dist-info → nshtrainer-0.14.1.dist-info}/WHEEL: File without changes