nshtrainer 0.13.1__tar.gz → 0.14.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/early_stopping.py +1 -3
  4. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/__init__.py +0 -2
  5. nshtrainer-0.14.1/src/nshtrainer/loggers/__init__.py +13 -0
  6. nshtrainer-0.14.1/src/nshtrainer/loggers/_base.py +26 -0
  7. nshtrainer-0.14.1/src/nshtrainer/loggers/csv.py +40 -0
  8. nshtrainer-0.14.1/src/nshtrainer/loggers/tensorboard.py +73 -0
  9. nshtrainer-0.14.1/src/nshtrainer/loggers/wandb.py +163 -0
  10. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/__init__.py +0 -2
  11. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/config.py +14 -259
  12. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/trainer/trainer.py +1 -1
  13. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/README.md +0 -0
  14. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/__init__.py +0 -0
  15. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  16. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  17. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  18. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  19. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  20. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  21. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  22. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  23. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  24. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  25. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/base.py +0 -0
  26. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  27. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  28. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  29. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  30. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  31. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/ema.py +0 -0
  32. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  33. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  34. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/interval.py +0 -0
  35. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  36. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  37. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  38. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  39. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/timer.py +0 -0
  40. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  41. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/data/__init__.py +0 -0
  42. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  43. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/data/transform.py +0 -0
  44. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/_experimental.py +0 -0
  45. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/actsave.py +0 -0
  46. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/callbacks.py +0 -0
  47. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/config.py +0 -0
  48. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/data.py +0 -0
  49. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/log.py +0 -0
  50. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  51. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/model.py +0 -0
  52. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/nn.py +0 -0
  53. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/optimizer.py +0 -0
  54. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/runner.py +0 -0
  55. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/snapshot.py +0 -0
  56. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/snoop.py +0 -0
  57. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/trainer.py +0 -0
  58. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/typecheck.py +0 -0
  59. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/ll/util.py +0 -0
  60. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  61. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  62. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  63. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  64. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/metrics/__init__.py +0 -0
  65. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/metrics/_config.py +0 -0
  66. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/base.py +0 -0
  67. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/callback.py +0 -0
  68. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/debug.py +0 -0
  69. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  70. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/logger.py +0 -0
  71. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  72. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  73. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  74. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/nn/__init__.py +0 -0
  75. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/nn/mlp.py +0 -0
  76. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/nn/module_dict.py +0 -0
  77. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/nn/module_list.py +0 -0
  78. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  79. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/optimizer.py +0 -0
  80. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/runner.py +0 -0
  81. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  82. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/trainer/__init__.py +0 -0
  83. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  84. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  85. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  86. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/_environment_info.py +0 -0
  87. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/_useful_types.py +0 -0
  88. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/environment.py +0 -0
  89. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/seed.py +0 -0
  90. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/slurm.py +0 -0
  91. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/typed.py +0 -0
  92. {nshtrainer-0.13.1 → nshtrainer-0.14.1}/src/nshtrainer/util/typing_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.13.1
+ Version: 0.14.1
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "nshtrainer"
- version = "0.13.1"
+ version = "0.14.1"
  description = ""
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
  readme = "README.md"
src/nshtrainer/callbacks/early_stopping.py
@@ -73,9 +73,7 @@ class EarlyStopping(_EarlyStopping):

      @override
      @staticmethod
-     def _log_info(
-         trainer: Trainer | None, message: str, log_rank_zero_only: bool
-     ) -> None:
+     def _log_info(trainer: Trainer | None, message: str, log_rank_zero_only: bool):
          rank = _get_rank()
          if trainer is not None and trainer.world_size <= 1:
              rank = None
src/nshtrainer/ll/__init__.py
@@ -23,7 +23,6 @@ from .log import pretty as pretty
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
  from .model import Base as Base
  from .model import BaseConfig as BaseConfig
- from .model import BaseLoggerConfig as BaseLoggerConfig
  from .model import BaseProfilerConfig as BaseProfilerConfig
  from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
  from .model import CheckpointSavingConfig as CheckpointSavingConfig
@@ -48,7 +47,6 @@ from .model import PrimaryMetricConfig as PrimaryMetricConfig
  from .model import ReproducibilityConfig as ReproducibilityConfig
  from .model import SanityCheckingConfig as SanityCheckingConfig
  from .model import TrainerConfig as TrainerConfig
- from .model import WandbWatchConfig as WandbWatchConfig
  from .nn import TypedModuleDict as TypedModuleDict
  from .nn import TypedModuleList as TypedModuleList
  from .optimizer import OptimizerConfig as OptimizerConfig
src/nshtrainer/loggers/__init__.py (new file)
@@ -0,0 +1,13 @@
+ from typing import Annotated, TypeAlias
+
+ import nshconfig as C
+
+ from ._base import BaseLoggerConfig as BaseLoggerConfig
+ from .csv import CSVLoggerConfig as CSVLoggerConfig
+ from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
+ from .wandb import WandbLoggerConfig as WandbLoggerConfig
+
+ LoggerConfig: TypeAlias = Annotated[
+     CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig,
+     C.Field(discriminator="name"),
+ ]
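Note: the new LoggerConfig alias is a tagged union discriminated on the `name` field, so a plain dict can be parsed straight into the matching config class. A minimal sketch, assuming nshconfig configs are pydantic v2 models (so pydantic's TypeAdapter applies); the dict values are illustrative only:

# Hypothetical usage sketch for the discriminated union above.
from pydantic import TypeAdapter

from nshtrainer.loggers import LoggerConfig, WandbLoggerConfig

adapter = TypeAdapter(LoggerConfig)

# The "name" field selects the concrete config class.
cfg = adapter.validate_python({"name": "wandb", "project": "my-project", "offline": True})
assert isinstance(cfg, WandbLoggerConfig)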
src/nshtrainer/loggers/_base.py (new file)
@@ -0,0 +1,26 @@
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING
+
+ import nshconfig as C
+ from lightning.pytorch.loggers import Logger
+
+ if TYPE_CHECKING:
+     from ..model import BaseConfig
+
+
+ class BaseLoggerConfig(C.Config, ABC):
+     enabled: bool = True
+     """Enable this logger."""
+
+     priority: int = 0
+     """Priority of the logger. Higher priority loggers are created first."""
+
+     log_dir: C.DirectoryPath | None = None
+     """Directory to save the logs to. If None, will use the default log directory for the trainer."""
+
+     @abstractmethod
+     def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...
+
+     def disable_(self):
+         self.enabled = False
+         return self
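A concrete logger config only needs a `name` discriminator and a `create_logger` implementation that returns None when disabled. A hypothetical subclass, shown purely to illustrate the contract (ExampleLoggerConfig is not part of nshtrainer):

from typing import Literal

from typing_extensions import override

from nshtrainer.loggers import BaseLoggerConfig


class ExampleLoggerConfig(BaseLoggerConfig):
    # Hypothetical example; not part of the package.
    name: Literal["example"] = "example"

    @override
    def create_logger(self, root_config):
        if not self.enabled:
            return None
        # A real implementation would build and return a lightning.pytorch.loggers.Logger here;
        # CSVLogger is used only as a stand-in backend.
        from lightning.pytorch.loggers.csv_logs import CSVLogger

        return CSVLogger(save_dir="logs", name=root_config.run_name)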
src/nshtrainer/loggers/csv.py (new file)
@@ -0,0 +1,40 @@
+ from typing import Literal
+
+ from typing_extensions import override
+
+ from ._base import BaseLoggerConfig
+
+
+ class CSVLoggerConfig(BaseLoggerConfig):
+     name: Literal["csv"] = "csv"
+
+     enabled: bool = True
+     """Enable CSV logging."""
+
+     priority: int = 0
+     """Priority of the logger. Higher priority loggers are created first."""
+
+     prefix: str = ""
+     """A string to put at the beginning of metric keys."""
+
+     flush_logs_every_n_steps: int = 100
+     """How often to flush logs to disk."""
+
+     @override
+     def create_logger(self, root_config):
+         if not self.enabled:
+             return None
+
+         from lightning.pytorch.loggers.csv_logs import CSVLogger
+
+         save_dir = root_config.directory._resolve_log_directory_for_logger(
+             root_config.id,
+             self,
+         )
+         return CSVLogger(
+             save_dir=save_dir,
+             name=root_config.run_name,
+             version=root_config.id,
+             prefix=self.prefix,
+             flush_logs_every_n_steps=self.flush_logs_every_n_steps,
+         )
src/nshtrainer/loggers/tensorboard.py (new file)
@@ -0,0 +1,73 @@
+ import logging
+ from typing import Literal
+
+ import nshconfig as C
+ from typing_extensions import override
+
+ from ._base import BaseLoggerConfig
+
+ log = logging.getLogger(__name__)
+
+
+ def _tensorboard_available():
+     try:
+         from lightning.fabric.loggers.tensorboard import (
+             _TENSORBOARD_AVAILABLE,
+             _TENSORBOARDX_AVAILABLE,
+         )
+
+         if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
+             log.warning(
+                 "TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. "
+                 "Please install TensorBoard with `pip install tensorboard` or "
+                 "TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging."
+             )
+             return False
+         return True
+     except ImportError:
+         return False
+
+
+ class TensorboardLoggerConfig(BaseLoggerConfig):
+     name: Literal["tensorboard"] = "tensorboard"
+
+     enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
+     """Enable TensorBoard logging."""
+
+     priority: int = 2
+     """Priority of the logger. Higher priority loggers are created first."""
+
+     log_graph: bool = False
+     """
+     Adds the computational graph to tensorboard. This requires that
+     the user has defined the `self.example_input_array` attribute in their
+     model.
+     """
+
+     default_hp_metric: bool = True
+     """
+     Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
+     called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
+     """
+
+     prefix: str = ""
+     """A string to put at the beginning of metric keys."""
+
+     @override
+     def create_logger(self, root_config):
+         if not self.enabled:
+             return None
+
+         from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
+
+         save_dir = root_config.directory._resolve_log_directory_for_logger(
+             root_config.id,
+             self,
+         )
+         return TensorBoardLogger(
+             save_dir=save_dir,
+             name=root_config.run_name,
+             version=root_config.id,
+             log_graph=self.log_graph,
+             default_hp_metric=self.default_hp_metric,
+         )
src/nshtrainer/loggers/wandb.py (new file)
@@ -0,0 +1,163 @@
+ import logging
+ from typing import TYPE_CHECKING, Literal
+
+ import nshconfig as C
+ import pkg_resources
+ from lightning.pytorch import Callback, LightningModule, Trainer
+ from typing_extensions import override
+
+ from ..callbacks import WandbWatchConfig
+ from ..callbacks.base import CallbackConfigBase
+ from ._base import BaseLoggerConfig
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ def _project_name(
+     root_config: "BaseConfig",
+     default_project: str = "lightning_logs",
+ ):
+     # If the config has a project name, use that.
+     if project := root_config.project:
+         return project
+
+     # Otherwise, we should use the name of the module that the config is defined in,
+     # if we can find it.
+     # If this isn't in a module, use the default project name.
+     if not (module := root_config.__module__):
+         return default_project
+
+     # If the module is a package, use the package name.
+     if not (module := module.split(".", maxsplit=1)[0].strip()):
+         return default_project
+
+     return module
+
+
+ def _wandb_available():
+     try:
+         from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
+
+         if not _WANDB_AVAILABLE:
+             log.warning("WandB not found. Disabling WandbLogger.")
+             return False
+         return True
+     except ImportError:
+         return False
+
+
+ class FinishWandbOnTeardownCallback(Callback):
+     @override
+     def teardown(
+         self,
+         trainer: Trainer,
+         pl_module: LightningModule,
+         stage: str,
+     ):
+         try:
+             import wandb  # type: ignore
+         except ImportError:
+             return
+
+         if wandb.run is None:
+             return
+
+         wandb.finish()
+
+
+ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
+     name: Literal["wandb"] = "wandb"
+
+     enabled: bool = C.Field(default_factory=lambda: _wandb_available())
+     """Enable WandB logging."""
+
+     priority: int = 2
+     """Priority of the logger. Higher priority loggers are created first,
+     and the highest priority logger is the "main" logger for PyTorch Lightning."""
+
+     project: str | None = None
+     """WandB project name to use for the logger. If None, will use the root config's project name."""
+
+     log_model: bool | Literal["all"] = False
+     """
+     Whether to log the model checkpoints to wandb.
+     Valid values are:
+     - False: Do not log the model checkpoints.
+     - True: Log the latest model checkpoint.
+     - "all": Log all model checkpoints.
+     """
+
+     watch: WandbWatchConfig | None = WandbWatchConfig()
+     """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
+
+     offline: bool = False
+     """Whether to run WandB in offline mode."""
+
+     use_wandb_core: bool = True
+     """Whether to use the new `wandb-core` backend for WandB.
+     `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
+     """
+
+     def offline_(self, value: bool = True):
+         self.offline = value
+         return self
+
+     def core_(self, value: bool = True):
+         self.use_wandb_core = value
+         return self
+
+     @override
+     def create_logger(self, root_config):
+         if not self.enabled:
+             return None
+
+         # If `wandb-core` is enabled, we should use the new backend.
+         if self.use_wandb_core:
+             try:
+                 import wandb  # type: ignore
+
+                 # The minimum version that supports the new backend is 0.17.5
+                 if pkg_resources.parse_version(
+                     wandb.__version__
+                 ) < pkg_resources.parse_version("0.17.5"):
+                     log.warning(
+                         "The version of WandB installed does not support the `wandb-core` backend "
+                         f"(expected version >= 0.17.5, found version {wandb.__version__}). "
+                         "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
+                     )
+                 else:
+                     wandb.require("core")
+                     log.critical("Using the `wandb-core` backend for WandB.")
+             except ImportError:
+                 pass
+
+         from lightning.pytorch.loggers.wandb import WandbLogger
+
+         save_dir = root_config.directory._resolve_log_directory_for_logger(
+             root_config.id,
+             self,
+         )
+         return WandbLogger(
+             save_dir=save_dir,
+             project=self.project or _project_name(root_config),
+             name=root_config.run_name,
+             version=root_config.id,
+             log_model=self.log_model,
+             notes=(
+                 "\n".join(f"- {note}" for note in root_config.notes)
+                 if root_config.notes
+                 else None
+             ),
+             tags=root_config.tags,
+             offline=self.offline,
+         )
+
+     @override
+     def create_callbacks(self, root_config):
+         yield FinishWandbOnTeardownCallback()
+
+         if self.watch:
+             yield from self.watch.create_callbacks(root_config)
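The mutator helpers (`offline_`, `core_`, `disable_`) return self, so a config can be tweaked fluently. A small usage sketch based only on the fields defined above (the project name is illustrative):

# Usage sketch for WandbLoggerConfig.
from nshtrainer.loggers import WandbLoggerConfig

wandb_logger = (
    WandbLoggerConfig(project="my-project", log_model=True)
    .offline_(True)   # run WandB in offline mode
    .core_(False)     # opt out of the wandb-core backend
)

# Since `watch` is now optional, model watching can be turned off by clearing it.
wandb_logger.watch = None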
src/nshtrainer/model/__init__.py
@@ -3,7 +3,6 @@ from typing_extensions import TypeAlias
  from .base import Base as Base
  from .base import LightningModuleBase as LightningModuleBase
  from .config import BaseConfig as BaseConfig
- from .config import BaseLoggerConfig as BaseLoggerConfig
  from .config import BaseProfilerConfig as BaseProfilerConfig
  from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
  from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
@@ -22,6 +21,5 @@ from .config import PrimaryMetricConfig as PrimaryMetricConfig
  from .config import ReproducibilityConfig as ReproducibilityConfig
  from .config import SanityCheckingConfig as SanityCheckingConfig
  from .config import TrainerConfig as TrainerConfig
- from .config import WandbWatchConfig as WandbWatchConfig

  ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
src/nshtrainer/model/config.py
@@ -19,7 +19,6 @@ from typing import (

  import nshconfig as C
  import numpy as np
- import pkg_resources
  import torch
  from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
  from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
@@ -31,7 +30,6 @@ from lightning.pytorch.plugins.layer_sync import LayerSync
  from lightning.pytorch.plugins.precision.precision import Precision
  from lightning.pytorch.profilers import Profiler
  from lightning.pytorch.strategies.strategy import Strategy
- from pydantic import DirectoryPath
  from typing_extensions import Self, TypedDict, TypeVar, override

  from .._checkpoint.loader import CheckpointLoadingConfig
@@ -41,9 +39,14 @@ from ..callbacks import (
      EarlyStoppingConfig,
      LastCheckpointCallbackConfig,
      OnExceptionCheckpointCallbackConfig,
-     WandbWatchConfig,
  )
  from ..callbacks.base import CallbackConfigBase
+ from ..loggers import (
+     CSVLoggerConfig,
+     LoggerConfig,
+     TensorboardLoggerConfig,
+     WandbLoggerConfig,
+ )
  from ..metrics import MetricConfig
  from ..util._environment_info import EnvironmentConfig

@@ -205,255 +208,6 @@ ProfilerConfig: TypeAlias = Annotated[
  ]


- class BaseLoggerConfig(C.Config, ABC):
-     enabled: bool = True
-     """Enable this logger."""
-
-     priority: int = 0
-     """Priority of the logger. Higher priority loggers are created first."""
-
-     log_dir: DirectoryPath | None = None
-     """Directory to save the logs to. If None, will use the default log directory for the trainer."""
-
-     @abstractmethod
-     def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...
-
-     def disable_(self):
-         self.enabled = False
-         return self
-
-
- def _project_name(
-     root_config: "BaseConfig",
-     default_project: str = "lightning_logs",
- ):
-     # If the config has a project name, use that.
-     if project := root_config.project:
-         return project
-
-     # Otherwise, we should use the name of the module that the config is defined in,
-     # if we can find it.
-     # If this isn't in a module, use the default project name.
-     if not (module := root_config.__module__):
-         return default_project
-
-     # If the module is a package, use the package name.
-     if not (module := module.split(".", maxsplit=1)[0].strip()):
-         return default_project
-
-     return module
-
-
- def _wandb_available():
-     try:
-         from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
-
-         if not _WANDB_AVAILABLE:
-             log.warning("WandB not found. Disabling WandbLogger.")
-             return False
-         return True
-     except ImportError:
-         return False
-
-
- class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
-     name: Literal["wandb"] = "wandb"
-
-     enabled: bool = C.Field(default_factory=lambda: _wandb_available())
-     """Enable WandB logging."""
-
-     priority: int = 2
-     """Priority of the logger. Higher priority loggers are created first,
-     and the highest priority logger is the "main" logger for PyTorch Lightning."""
-
-     project: str | None = None
-     """WandB project name to use for the logger. If None, will use the root config's project name."""
-
-     log_model: bool | Literal["all"] = False
-     """
-     Whether to log the model checkpoints to wandb.
-     Valid values are:
-     - False: Do not log the model checkpoints.
-     - True: Log the latest model checkpoint.
-     - "all": Log all model checkpoints.
-     """
-
-     watch: WandbWatchConfig = WandbWatchConfig()
-     """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
-
-     offline: bool = False
-     """Whether to run WandB in offline mode."""
-
-     use_wandb_core: bool = True
-     """Whether to use the new `wandb-core` backend for WandB.
-     `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
-     """
-
-     def offline_(self, value: bool = True):
-         self.offline = value
-         return self
-
-     def core_(self, value: bool = True):
-         self.use_wandb_core = value
-         return self
-
-     @override
-     def create_logger(self, root_config):
-         if not self.enabled:
-             return None
-
-         # If `wandb-core` is enabled, we should use the new backend.
-         if self.use_wandb_core:
-             try:
-                 import wandb  # type: ignore
-
-                 # The minimum version that supports the new backend is 0.17.5
-                 if pkg_resources.parse_version(
-                     wandb.__version__
-                 ) < pkg_resources.parse_version("0.17.5"):
-                     log.warning(
-                         "The version of WandB installed does not support the `wandb-core` backend "
-                         f"(expected version >= 0.17.5, found version {wandb.__version__}). "
-                         "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
-                     )
-                 else:
-                     wandb.require("core")
-                     log.critical("Using the `wandb-core` backend for WandB.")
-             except ImportError:
-                 pass
-
-         from lightning.pytorch.loggers.wandb import WandbLogger
-
-         save_dir = root_config.directory._resolve_log_directory_for_logger(
-             root_config.id,
-             self,
-         )
-         return WandbLogger(
-             save_dir=save_dir,
-             project=self.project or _project_name(root_config),
-             name=root_config.run_name,
-             version=root_config.id,
-             log_model=self.log_model,
-             notes=(
-                 "\n".join(f"- {note}" for note in root_config.notes)
-                 if root_config.notes
-                 else None
-             ),
-             tags=root_config.tags,
-             offline=self.offline,
-         )
-
-     @override
-     def create_callbacks(self, root_config):
-         if self.watch:
-             yield from self.watch.create_callbacks(root_config)
-
-
- class CSVLoggerConfig(BaseLoggerConfig):
-     name: Literal["csv"] = "csv"
-
-     enabled: bool = True
-     """Enable CSV logging."""
-
-     priority: int = 0
-     """Priority of the logger. Higher priority loggers are created first."""
-
-     prefix: str = ""
-     """A string to put at the beginning of metric keys."""
-
-     flush_logs_every_n_steps: int = 100
-     """How often to flush logs to disk."""
-
-     @override
-     def create_logger(self, root_config):
-         if not self.enabled:
-             return None
-
-         from lightning.pytorch.loggers.csv_logs import CSVLogger
-
-         save_dir = root_config.directory._resolve_log_directory_for_logger(
-             root_config.id,
-             self,
-         )
-         return CSVLogger(
-             save_dir=save_dir,
-             name=root_config.run_name,
-             version=root_config.id,
-             prefix=self.prefix,
-             flush_logs_every_n_steps=self.flush_logs_every_n_steps,
-         )
-
-
- def _tensorboard_available():
-     try:
-         from lightning.fabric.loggers.tensorboard import (
-             _TENSORBOARD_AVAILABLE,
-             _TENSORBOARDX_AVAILABLE,
-         )
-
-         if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
-             log.warning(
-                 "TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. "
-                 "Please install TensorBoard with `pip install tensorboard` or "
-                 "TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging."
-             )
-             return False
-         return True
-     except ImportError:
-         return False
-
-
- class TensorboardLoggerConfig(BaseLoggerConfig):
-     name: Literal["tensorboard"] = "tensorboard"
-
-     enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
-     """Enable TensorBoard logging."""
-
-     priority: int = 2
-     """Priority of the logger. Higher priority loggers are created first."""
-
-     log_graph: bool = False
-     """
-     Adds the computational graph to tensorboard. This requires that
-     the user has defined the `self.example_input_array` attribute in their
-     model.
-     """
-
-     default_hp_metric: bool = True
-     """
-     Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
-     called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
-     """
-
-     prefix: str = ""
-     """A string to put at the beginning of metric keys."""
-
-     @override
-     def create_logger(self, root_config):
-         if not self.enabled:
-             return None
-
-         from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
-
-         save_dir = root_config.directory._resolve_log_directory_for_logger(
-             root_config.id,
-             self,
-         )
-         return TensorBoardLogger(
-             save_dir=save_dir,
-             name=root_config.run_name,
-             version=root_config.id,
-             log_graph=self.log_graph,
-             default_hp_metric=self.default_hp_metric,
-         )
-
-
- LoggerConfig: TypeAlias = Annotated[
-     WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
-     C.Field(discriminator="name"),
- ]
-
-
  class LoggingConfig(CallbackConfigBase):
      enabled: bool = True
      """Enable experiment tracking."""
@@ -474,29 +228,32 @@ class LoggingConfig(CallbackConfigBase):
      """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""

      @property
-     def wandb(self) -> WandbLoggerConfig | None:
+     def wandb(self):
          return next(
              (
                  logger
                  for logger in self.loggers
                  if isinstance(logger, WandbLoggerConfig)
              ),
+             None,
          )

      @property
-     def csv(self) -> CSVLoggerConfig | None:
+     def csv(self):
          return next(
              (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+             None,
          )

      @property
-     def tensorboard(self) -> TensorboardLoggerConfig | None:
+     def tensorboard(self):
          return next(
              (
                  logger
                  for logger in self.loggers
                  if isinstance(logger, TensorboardLoggerConfig)
              ),
+             None,
          )

      def create_loggers(self, root_config: "BaseConfig"):
@@ -509,9 +266,8 @@ class LoggingConfig(CallbackConfigBase):
          Returns:
              list[Logger]: A list of constructed loggers.
          """
-         loggers: list[Logger] = []
          if not self.enabled:
-             return loggers
+             return

          for logger_config in sorted(
              self.loggers,
@@ -522,8 +278,7 @@ class LoggingConfig(CallbackConfigBase):
                  continue
              if (logger := logger_config.create_logger(root_config)) is None:
                  continue
-             loggers.append(logger)
-         return loggers
+             yield logger

      @override
      def create_callbacks(self, root_config):
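Since the wandb / csv / tensorboard properties now pass a None default to next(...), callers can probe for a logger without catching StopIteration. A hedged sketch, assuming `config` is an nshtrainer BaseConfig instance with the `trainer.logging` field used elsewhere in this diff:

# Sketch: the properties return None when the corresponding logger is not configured.
def force_wandb_offline(config) -> None:
    # `config` is assumed to be an nshtrainer BaseConfig instance.
    if (wandb_cfg := config.trainer.logging.wandb) is not None:
        wandb_cfg.offline_(True)

    if config.trainer.logging.tensorboard is None:
        print("TensorBoard logging is not configured for this run.")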
src/nshtrainer/trainer/trainer.py
@@ -244,7 +244,7 @@ class Trainer(LightningTrainer):
              log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
              kwargs["logger"] = False
          else:
-             _update_kwargs(logger=config.trainer.logging.create_loggers(config))
+             _update_kwargs(logger=list(config.trainer.logging.create_loggers(config)))

          if config.trainer.auto_determine_num_nodes:
              # When num_nodes is auto, we need to detect the number of nodes.
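create_loggers now yields loggers lazily instead of returning a list, which is why the trainer wraps it in list(...) before handing it to Lightning. Any other caller needs the same treatment:

# create_loggers is a generator in 0.14.1; materialize it before passing to Lightning.
loggers = list(config.trainer.logging.create_loggers(config))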