nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/model/mixins/debug.py (new file)
@@ -0,0 +1,86 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any
+
+ import torch
+
+ log = logging.getLogger(__name__)
+
+
+ def _trainer(module: Any):
+     if torch.jit.is_scripting():
+         return None
+
+     if hasattr(module, "_trainer"):
+         trainer = module._trainer
+     else:
+         try:
+             trainer = module.trainer
+         except RuntimeError:
+             return None
+
+     from ...trainer import Trainer
+
+     if not isinstance(trainer, Trainer):
+         return None
+
+     return trainer
+
+
+ class _DebugModuleMixin:
+     @property
+     def nshtrainer_or_none(self):
+         return _trainer(self)
+
+     @property
+     def nshtrainer(self):
+         if (trainer := _trainer(self)) is None:
+             raise RuntimeError("Could not resolve trainer.")
+         return trainer
+
+     @property
+     def debug(self) -> bool:
+         if (trainer := _trainer(self)) is None:
+             return False
+         return trainer.debug
+
+     @debug.setter
+     def debug(self, value: bool):
+         if (trainer := _trainer(self)) is None:
+             return
+         trainer.debug = value
+
+     @torch.jit.unused
+     def breakpoint(self, rank_zero_only: bool = True):
+         if (
+             not rank_zero_only
+             or not torch.distributed.is_initialized()
+             or torch.distributed.get_rank() == 0
+         ):
+             breakpoint()
+
+         if rank_zero_only and torch.distributed.is_initialized():
+             _ = torch.distributed.barrier()
+
+     @torch.jit.unused
+     def ensure_finite(
+         self,
+         tensor: torch.Tensor,
+         name: str | None = None,
+         throw: bool = False,
+     ):
+         name_parts: list[str] = ["Tensor"]
+         if name is not None:
+             name_parts.append(name)
+         name = " ".join(name_parts)
+
+         not_finite = ~torch.isfinite(tensor)
+         if not_finite.any():
+             msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+             if throw:
+                 raise RuntimeError(msg)
+             else:
+                 log.warning(msg)
+             return False
+         return True
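The hunk above adds nshtrainer/model/mixins/debug.py, which gathers the debug helpers into a standalone mixin. A minimal usage sketch, assuming the mixin is composed into a user LightningModule (the module class and loss helper below are hypothetical; inside the package the mixin is presumably composed into nshtrainer's own base module class):

    from __future__ import annotations

    from lightning.pytorch import LightningModule

    from nshtrainer.model.mixins.debug import _DebugModuleMixin


    class MyModule(_DebugModuleMixin, LightningModule):  # hypothetical module
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            # Warns (or raises, with throw=True) if the tensor has NaN/Inf entries.
            self.ensure_finite(loss, name="loss")
            if self.debug and batch_idx == 0:
                # Drops into pdb on rank 0 only; other ranks wait at a barrier.
                self.breakpoint()
            return loss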
nshtrainer/model/mixins/logger.py
@@ -1,166 +1,163 @@
  from __future__ import annotations

+ import copy
+ import dataclasses
  from collections import deque
  from collections.abc import Callable, Generator
  from contextlib import contextmanager
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, cast
+ from typing import Any, ClassVar

- import torchmetrics
  from lightning.pytorch import LightningModule
  from lightning.pytorch.utilities.types import _METRIC
  from lightning_utilities.core.rank_zero import rank_zero_warn
- from nshutils import ActSave
- from typing_extensions import override
+ from typing_extensions import Self, override

  from ...util.typing_utils import mixin_base_type
- from ..config import BaseConfig


- @dataclass(frozen=True, kw_only=True)
- class _LogContext:
+ @dataclasses.dataclass(frozen=True, kw_only=True)
+ class _LogContextKwargs:
+     __ignore_fields__: ClassVar[set[str]] = {"prefix", "disabled"}
+
      prefix: str | None = None
      disabled: bool | None = None
-     kwargs: dict[str, Any] = field(default_factory=dict)
-
-
- class LoggerModuleMixin(mixin_base_type(LightningModule)):
-     @property
-     def log_dir(self):
-         """
-         The directory where logs are saved.
-         """
-         if (trainer := self._trainer) is None:
-             raise RuntimeError("trainer is not defined")
-
-         if (logger := trainer.logger) is None:
-             raise RuntimeError("trainer.logger is not defined")
-
-         if (log_dir := logger.log_dir) is None:
-             raise RuntimeError("trainer.logger.log_dir is not defined")
-
-         return Path(log_dir)
-
-     @property
-     def should_update_logs(self):
-         """
-         Whether logs should be updated. This is true once every `log_every_n_steps` steps.
-         """
-         if self._trainer is None:
-             raise RuntimeError(
-                 "`should_update_logs` can only be used after the module is attached to a trainer"
-             )
-
-         return self._trainer._logger_connector.should_update_logs
-
-
- class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningModule)):
+     prog_bar: bool | None = None
+     logger: bool | None = None
+     on_step: bool | None = None
+     on_epoch: bool | None = None
+     reduce_fx: str | Callable | None = None
+     enable_graph: bool | None = None
+     sync_dist: bool | None = None
+     sync_dist_group: Any | None = None
+     add_dataloader_idx: bool | None = None
+     batch_size: int | None = None
+     rank_zero_only: bool | None = None
+
+     def copy_from(self, other: Self):
+         kwargs = copy.deepcopy(self)
+
+         # Copy over all the not-None values from the other object
+         updates = {}
+         for field in dataclasses.fields(self):
+             # Ignore disabled fields
+             if field.name in self.__ignore_fields__:
+                 continue
+
+             if (value := getattr(other, field.name, None)) is None:
+                 continue
+             # setattr(kwargs, field.name, value)
+             updates[field.name] = value
+
+         return dataclasses.replace(kwargs, **updates)
+
+     def to_dict(self):
+         d = dataclasses.asdict(self)
+         for field in self.__ignore_fields__:
+             d.pop(field, None)
+         return d
+
+
+ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
      @override
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

-         self._logger_prefix_stack = deque[_LogContext]()
-
-     if TYPE_CHECKING:
-
-         @contextmanager
-         def log_context(
-             self,
-             prefix: str | None = None,
-             *,
-             disabled: bool | None = None,
-             prog_bar: bool | None = None,
-             logger: bool | None = None,
-             on_step: bool | None = None,
-             on_epoch: bool | None = None,
-             reduce_fx: str | Callable | None = None,
-             enable_graph: bool | None = None,
-             sync_dist: bool | None = None,
-             sync_dist_group: Any | None = None,
-             add_dataloader_idx: bool | None = None,
-             batch_size: int | None = None,
-             rank_zero_only: bool | None = None,
-         ) -> Generator[None, None, None]: ...
-
-     else:
-
-         @contextmanager
-         def log_context(
-             self, prefix: str | None = None, *, disabled: bool | None = None, **kwargs
-         ) -> Generator[None, None, None]:
-             self._logger_prefix_stack.append(
-                 _LogContext(
-                     prefix=prefix,
-                     disabled=disabled,
-                     kwargs=kwargs,
-                 )
+         self._logger_prefix_stack = deque[_LogContextKwargs]()
+
+     @contextmanager
+     def log_context(
+         self,
+         prefix: str | None = None,
+         disabled: bool | None = None,
+         prog_bar: bool | None = None,
+         logger: bool | None = None,
+         on_step: bool | None = None,
+         on_epoch: bool | None = None,
+         reduce_fx: str | Callable | None = None,
+         enable_graph: bool | None = None,
+         sync_dist: bool | None = None,
+         sync_dist_group: Any | None = None,
+         add_dataloader_idx: bool | None = None,
+         batch_size: int | None = None,
+         rank_zero_only: bool | None = None,
+     ) -> Generator[None, None, None]:
+         self._logger_prefix_stack.append(
+             _LogContextKwargs(
+                 prefix=prefix,
+                 disabled=disabled,
+                 prog_bar=prog_bar,
+                 logger=logger,
+                 on_step=on_step,
+                 on_epoch=on_epoch,
+                 reduce_fx=reduce_fx,
+                 enable_graph=enable_graph,
+                 sync_dist=sync_dist,
+                 sync_dist_group=sync_dist_group,
+                 add_dataloader_idx=add_dataloader_idx,
+                 batch_size=batch_size,
+                 rank_zero_only=rank_zero_only,
              )
-             try:
-                 yield
-             finally:
-                 _ = self._logger_prefix_stack.pop()
-
-     if TYPE_CHECKING:
-
-         @override
-         def log(  # type: ignore[override]
-             self,
-             name: str,
-             value: _METRIC,
-             *,
-             prog_bar: bool = False,
-             logger: bool | None = None,
-             on_step: bool | None = None,
-             on_epoch: bool | None = None,
-             reduce_fx: str | Callable = "mean",
-             enable_graph: bool = False,
-             sync_dist: bool = False,
-             sync_dist_group: Any | None = None,
-             add_dataloader_idx: bool = True,
-             batch_size: int | None = None,
-             metric_attribute: str | None = None,
-             rank_zero_only: bool = False,
-         ) -> None: ...
-
-     else:
-
-         @override
-         def log(self, name: str, value: _METRIC, **kwargs) -> None:
-             # join all prefixes
-             prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
-             name = f"{prefix}{name}"
-
-             # check for disabled context:
-             # if the topmost non-null context is disabled, then we don't log
-             for c in reversed(self._logger_prefix_stack):
-                 if c.disabled is not None:
-                     if c.disabled:
-                         rank_zero_warn(
-                             f"Skipping logging of {name} due to disabled context"
-                         )
-                         return
-                     else:
-                         break
-
-             fn_kwargs = {}
-             for c in self._logger_prefix_stack:
-                 fn_kwargs.update(c.kwargs)
-             fn_kwargs.update(kwargs)
-
-             self._logger_actsave(name, value)
-
-             return super().log(name, value, **fn_kwargs)
-
-     def _logger_actsave(self, name: str, value: _METRIC) -> None:
-         hparams = cast(BaseConfig, self.hparams)
-         if not hparams.trainer.logging.actsave_logged_metrics:
-             return
-
-         ActSave.save(
-             lambda: {
-                 f"logger.{name}": lambda: value.compute()
-                 if isinstance(value, torchmetrics.Metric)
-                 else value
-             }
+         )
+         try:
+             yield
+         finally:
+             _ = self._logger_prefix_stack.pop()
+
+     @override
+     def log(
+         self,
+         name: str,
+         value: _METRIC,
+         prog_bar: bool = False,
+         logger: bool | None = None,
+         on_step: bool | None = None,
+         on_epoch: bool | None = None,
+         reduce_fx: str | Callable = "mean",
+         enable_graph: bool = False,
+         sync_dist: bool = False,
+         sync_dist_group: Any | None = None,
+         add_dataloader_idx: bool = True,
+         batch_size: int | None = None,
+         metric_attribute: str | None = None,
+         rank_zero_only: bool = False,
+     ) -> None:
+         # join all prefixes
+         prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+         name = f"{prefix}{name}"
+
+         # check for disabled context:
+         # if the topmost non-null context is disabled, then we don't log
+         for c in reversed(self._logger_prefix_stack):
+             if c.disabled is not None:
+                 if c.disabled:
+                     rank_zero_warn(
+                         f"Skipping logging of {name} due to disabled context"
+                     )
+                     return
+                 else:
+                     break
+
+         fn_kwargs = _LogContextKwargs()
+         for c in self._logger_prefix_stack:
+             fn_kwargs = fn_kwargs.copy_from(c)
+         fn_kwargs = fn_kwargs.copy_from(
+             _LogContextKwargs(
+                 prog_bar=prog_bar,
+                 logger=logger,
+                 on_step=on_step,
+                 on_epoch=on_epoch,
+                 reduce_fx=reduce_fx,
+                 enable_graph=enable_graph,
+                 sync_dist=sync_dist,
+                 sync_dist_group=sync_dist_group,
+                 add_dataloader_idx=add_dataloader_idx,
+                 batch_size=batch_size,
+                 rank_zero_only=rank_zero_only,
+             )
+         )
+         return super().log(
+             name,
+             value,
+             metric_attribute=metric_attribute,
+             **fn_kwargs.to_dict(),
          )
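The reworked nshtrainer/model/mixins/logger.py drops the TYPE_CHECKING stub split and the ActSave hook; log_context now carries explicit keyword fields that are merged with each log call via _LogContextKwargs.copy_from, and context values apply only to arguments the call leaves at None. A hedged usage sketch (the module class and loss helper are hypothetical):

    from lightning.pytorch import LightningModule

    from nshtrainer.model.mixins.logger import LoggerLightningModuleMixin


    class MyModule(LoggerLightningModuleMixin, LightningModule):  # hypothetical module
        def validation_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            with self.log_context(prefix="val/", on_epoch=True):
                # Nested prefixes are concatenated, so this is logged as "val/loss".
                # on_epoch=True is taken from the context because the call leaves it
                # unset (None).
                self.log("loss", loss)
            return loss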
nshtrainer/profiler/_base.py
@@ -9,7 +9,7 @@ import nshconfig as C
  from lightning.pytorch.profilers import Profiler

  if TYPE_CHECKING:
-     from ..model import BaseConfig
+     from ..trainer._config import TrainerConfig

  log = logging.getLogger(__name__)

@@ -28,4 +28,4 @@ class BaseProfilerConfig(C.Config, ABC):
      """

      @abstractmethod
-     def create_profiler(self, root_config: "BaseConfig") -> Profiler | None: ...
+     def create_profiler(self, trainer_config: TrainerConfig) -> Profiler | None: ...
nshtrainer/profiler/advanced.py
@@ -21,16 +21,16 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
      """

      @override
-     def create_profiler(self, root_config):
+     def create_profiler(self, trainer_config):
          from lightning.pytorch.profilers.advanced import AdvancedProfiler

          if (dirpath := self.dirpath) is None:
-             dirpath = root_config.directory.resolve_subdirectory(
-                 root_config.id, "profile"
+             dirpath = trainer_config.directory.resolve_subdirectory(
+                 trainer_config.id, "profile"
              )

          if (filename := self.filename) is None:
-             filename = f"{root_config.id}_profile.txt"
+             filename = f"{trainer_config.id}_profile.txt"

          return AdvancedProfiler(
              line_count_restriction=self.line_count_restriction,
nshtrainer/profiler/pytorch.py
@@ -60,16 +60,16 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
      """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""

      @override
-     def create_profiler(self, root_config):
+     def create_profiler(self, trainer_config):
          from lightning.pytorch.profilers.pytorch import PyTorchProfiler

          if (dirpath := self.dirpath) is None:
-             dirpath = root_config.directory.resolve_subdirectory(
-                 root_config.id, "profile"
+             dirpath = trainer_config.directory.resolve_subdirectory(
+                 trainer_config.id, "profile"
              )

          if (filename := self.filename) is None:
-             filename = f"{root_config.id}_profile.txt"
+             filename = f"{trainer_config.id}_profile.txt"

          return PyTorchProfiler(
              group_by_input_shapes=self.group_by_input_shapes,
nshtrainer/profiler/simple.py
@@ -20,16 +20,16 @@ class SimpleProfilerConfig(BaseProfilerConfig):
      """

      @override
-     def create_profiler(self, root_config):
+     def create_profiler(self, trainer_config):
          from lightning.pytorch.profilers.simple import SimpleProfiler

          if (dirpath := self.dirpath) is None:
-             dirpath = root_config.directory.resolve_subdirectory(
-                 root_config.id, "profile"
+             dirpath = trainer_config.directory.resolve_subdirectory(
+                 trainer_config.id, "profile"
              )

          if (filename := self.filename) is None:
-             filename = f"{root_config.id}_profile.txt"
+             filename = f"{trainer_config.id}_profile.txt"

          return SimpleProfiler(
              extended=self.extended,
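Across _base.py, advanced.py, pytorch.py, and simple.py, create_profiler now receives the new TrainerConfig rather than the old root BaseConfig. A rough sketch of a custom profiler config written against the new signature (the class name is hypothetical, and any BaseProfilerConfig fields beyond what the hunks show are assumptions):

    from __future__ import annotations

    from typing_extensions import override

    from nshtrainer.profiler._base import BaseProfilerConfig


    class MyProfilerConfig(BaseProfilerConfig):  # hypothetical subclass
        @override
        def create_profiler(self, trainer_config):
            from lightning.pytorch.profilers.simple import SimpleProfiler

            # Mirror the built-in configs: derive the output directory and
            # filename from the trainer config's id and directory settings.
            dirpath = trainer_config.directory.resolve_subdirectory(
                trainer_config.id, "profile"
            )
            return SimpleProfiler(
                dirpath=dirpath, filename=f"{trainer_config.id}_profile.txt"
            )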
nshtrainer/trainer/__init__.py
@@ -1,3 +1,4 @@
  from __future__ import annotations

+ from ._config import TrainerConfig as TrainerConfig
  from .trainer import Trainer as Trainer
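With this re-export, TrainerConfig can be imported directly alongside Trainer:

    from nshtrainer.trainer import Trainer, TrainerConfig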