nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/model/mixins/debug.py
ADDED
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+import torch
+
+log = logging.getLogger(__name__)
+
+
+def _trainer(module: Any):
+    if torch.jit.is_scripting():
+        return None
+
+    if hasattr(module, "_trainer"):
+        trainer = module._trainer
+    else:
+        try:
+            trainer = module.trainer
+        except RuntimeError:
+            return None
+
+    from ...trainer import Trainer
+
+    if not isinstance(trainer, Trainer):
+        return None
+
+    return trainer
+
+
+class _DebugModuleMixin:
+    @property
+    def nshtrainer_or_none(self):
+        return _trainer(self)
+
+    @property
+    def nshtrainer(self):
+        if (trainer := _trainer(self)) is None:
+            raise RuntimeError("Could not resolve trainer.")
+        return trainer
+
+    @property
+    def debug(self) -> bool:
+        if (trainer := _trainer(self)) is None:
+            return False
+        return trainer.debug
+
+    @debug.setter
+    def debug(self, value: bool):
+        if (trainer := _trainer(self)) is None:
+            return
+        trainer.debug = value
+
+    @torch.jit.unused
+    def breakpoint(self, rank_zero_only: bool = True):
+        if (
+            not rank_zero_only
+            or not torch.distributed.is_initialized()
+            or torch.distributed.get_rank() == 0
+        ):
+            breakpoint()
+
+        if rank_zero_only and torch.distributed.is_initialized():
+            _ = torch.distributed.barrier()
+
+    @torch.jit.unused
+    def ensure_finite(
+        self,
+        tensor: torch.Tensor,
+        name: str | None = None,
+        throw: bool = False,
+    ):
+        name_parts: list[str] = ["Tensor"]
+        if name is not None:
+            name_parts.append(name)
+        name = " ".join(name_parts)
+
+        not_finite = ~torch.isfinite(tensor)
+        if not_finite.any():
+            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+            if throw:
+                raise RuntimeError(msg)
+            else:
+                log.warning(msg)
+                return False
+        return True
nshtrainer/model/mixins/logger.py
CHANGED
@@ -1,166 +1,163 @@
 from __future__ import annotations
 
+import copy
+import dataclasses
 from collections import deque
 from collections.abc import Callable, Generator
 from contextlib import contextmanager
-from
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import Any, ClassVar
 
-import torchmetrics
 from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import _METRIC
 from lightning_utilities.core.rank_zero import rank_zero_warn
-from
-from typing_extensions import override
+from typing_extensions import Self, override
 
 from ...util.typing_utils import mixin_base_type
-from ..config import BaseConfig
 
 
-@dataclass(frozen=True, kw_only=True)
-class
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _LogContextKwargs:
+    __ignore_fields__: ClassVar[set[str]] = {"prefix", "disabled"}
+
     prefix: str | None = None
     disabled: bool | None = None
+    prog_bar: bool | None = None
+    logger: bool | None = None
+    on_step: bool | None = None
+    on_epoch: bool | None = None
+    reduce_fx: str | Callable | None = None
+    enable_graph: bool | None = None
+    sync_dist: bool | None = None
+    sync_dist_group: Any | None = None
+    add_dataloader_idx: bool | None = None
+    batch_size: int | None = None
+    rank_zero_only: bool | None = None
+
+    def copy_from(self, other: Self):
+        kwargs = copy.deepcopy(self)
+
+        # Copy over all the not-None values from the other object
+        updates = {}
+        for field in dataclasses.fields(self):
+            # Ignore disabled fields
+            if field.name in self.__ignore_fields__:
+                continue
+
+            if (value := getattr(other, field.name, None)) is None:
+                continue
+            # setattr(kwargs, field.name, value)
+            updates[field.name] = value
+
+        return dataclasses.replace(kwargs, **updates)
+
+    def to_dict(self):
+        d = dataclasses.asdict(self)
+        for field in self.__ignore_fields__:
+            d.pop(field, None)
+        return d
+
+
+class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
     @override
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self._logger_prefix_stack = deque[
-        )
+        self._logger_prefix_stack = deque[_LogContextKwargs]()
+
+    @contextmanager
+    def log_context(
+        self,
+        prefix: str | None = None,
+        disabled: bool | None = None,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> Generator[None, None, None]:
+        self._logger_prefix_stack.append(
+            _LogContextKwargs(
+                prefix=prefix,
+                disabled=disabled,
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
             )
-                f"logger.{name}": lambda: value.compute()
-                if isinstance(value, torchmetrics.Metric)
-                else value
-            }
+        )
+        try:
+            yield
+        finally:
+            _ = self._logger_prefix_stack.pop()
+
+    @override
+    def log(
+        self,
+        name: str,
+        value: _METRIC,
+        prog_bar: bool = False,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable = "mean",
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool = True,
+        batch_size: int | None = None,
+        metric_attribute: str | None = None,
+        rank_zero_only: bool = False,
+    ) -> None:
+        # join all prefixes
+        prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+        name = f"{prefix}{name}"
+
+        # check for disabled context:
+        # if the topmost non-null context is disabled, then we don't log
+        for c in reversed(self._logger_prefix_stack):
+            if c.disabled is not None:
+                if c.disabled:
+                    rank_zero_warn(
+                        f"Skipping logging of {name} due to disabled context"
+                    )
+                    return
+                else:
+                    break
+
+        fn_kwargs = _LogContextKwargs()
+        for c in self._logger_prefix_stack:
+            fn_kwargs = fn_kwargs.copy_from(c)
+        fn_kwargs = fn_kwargs.copy_from(
+            _LogContextKwargs(
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        return super().log(
+            name,
+            value,
+            metric_attribute=metric_attribute,
+            **fn_kwargs.to_dict(),
         )
nshtrainer/profiler/_base.py
CHANGED
@@ -9,7 +9,7 @@ import nshconfig as C
 from lightning.pytorch.profilers import Profiler
 
 if TYPE_CHECKING:
-    from ..
+    from ..trainer._config import TrainerConfig
 
 log = logging.getLogger(__name__)
 
@@ -28,4 +28,4 @@ class BaseProfilerConfig(C.Config, ABC):
     """
 
     @abstractmethod
-    def create_profiler(self,
+    def create_profiler(self, trainer_config: TrainerConfig) -> Profiler | None: ...
nshtrainer/profiler/advanced.py
CHANGED
@@ -21,16 +21,16 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
     """
 
     @override
-    def create_profiler(self,
+    def create_profiler(self, trainer_config):
         from lightning.pytorch.profilers.advanced import AdvancedProfiler
 
         if (dirpath := self.dirpath) is None:
-            dirpath =
-
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "profile"
             )
 
         if (filename := self.filename) is None:
-            filename = f"{
+            filename = f"{trainer_config.id}_profile.txt"
 
         return AdvancedProfiler(
             line_count_restriction=self.line_count_restriction,
nshtrainer/profiler/pytorch.py
CHANGED
@@ -60,16 +60,16 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
     """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
 
     @override
-    def create_profiler(self,
+    def create_profiler(self, trainer_config):
         from lightning.pytorch.profilers.pytorch import PyTorchProfiler
 
         if (dirpath := self.dirpath) is None:
-            dirpath =
-
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "profile"
             )
 
         if (filename := self.filename) is None:
-            filename = f"{
+            filename = f"{trainer_config.id}_profile.txt"
 
         return PyTorchProfiler(
             group_by_input_shapes=self.group_by_input_shapes,
nshtrainer/profiler/simple.py
CHANGED
@@ -20,16 +20,16 @@ class SimpleProfilerConfig(BaseProfilerConfig):
     """
 
     @override
-    def create_profiler(self,
+    def create_profiler(self, trainer_config):
         from lightning.pytorch.profilers.simple import SimpleProfiler
 
         if (dirpath := self.dirpath) is None:
-            dirpath =
-
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "profile"
             )
 
         if (filename := self.filename) is None:
-            filename = f"{
+            filename = f"{trainer_config.id}_profile.txt"
 
         return SimpleProfiler(
             extended=self.extended,
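
The common change across the profiler configs above is that create_profiler now receives the TrainerConfig and derives its default output location from it. A hedged sketch of that contract follows; the subclass is hypothetical and not part of nshtrainer.

from typing import TYPE_CHECKING

from typing_extensions import override

from nshtrainer.profiler._base import BaseProfilerConfig

if TYPE_CHECKING:
    from nshtrainer.trainer._config import TrainerConfig


class ProfileToRunDirConfig(BaseProfilerConfig):
    """Hypothetical config that always writes profiles to the run's 'profile' subdirectory."""

    @override
    def create_profiler(self, trainer_config: "TrainerConfig"):
        from lightning.pytorch.profilers.simple import SimpleProfiler

        # Mirrors the default-resolution pattern used by the built-in configs above.
        dirpath = trainer_config.directory.resolve_subdirectory(trainer_config.id, "profile")
        return SimpleProfiler(dirpath=dirpath, filename=f"{trainer_config.id}_profile.txt")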
nshtrainer/trainer/__init__.py
CHANGED