fkat-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/loggers.py
ADDED
@@ -0,0 +1,284 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Protocol, TYPE_CHECKING
from typing_extensions import override

import lightning as L
from lightning.pytorch.utilities import rank_zero_only
from mlflow.entities import Metric, RunTag, Param
from mlflow.tracking import MlflowClient  # type: ignore[possibly-unbound-import]

if TYPE_CHECKING:
    from lightning.pytorch.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger

from fkat.utils import assert_not_none


def _is_logger_type(logger: Any, logger_name: str) -> bool:
    """Check if logger matches type from lightning or pytorch_lightning."""
    module = type(logger).__module__
    return type(logger).__name__ == logger_name and (
        module.startswith("lightning.pytorch.loggers") or module.startswith("pytorch_lightning.loggers")
    )


class LightningLogger(Protocol):
    """
    Protocol defining the interface for logging that handle metrics, tags, and artifacts.
    """

    def tags(self) -> dict[str, Any]:
        """Get current tags"""
        ...

    def log_tag(self, key: str, value: str) -> None:
        """
        Log a single key-value tag.

        Args:
            key (str): The identifier/name of the tag
            value (str): The value associated with the tag
        """
        ...

    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        """
        Log multiple metrics and/or tags in a single batch operation.

        Args:
            metrics (dict[str, float], optional): Dictionary mapping metric names to their float values
            params (dict[str, Any], optional): Dictionary mapping params names to their values
            tags (dict[str, str], optional): Dictionary mapping tag names to their string values
            timestamp (int, optional): Unix timestamp for when the batch was logged
            step (int, optional): Training step or iteration number
        """
        ...

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        """
        Log a local file as an artifact.

        Args:
            local_path (str): Path to the file on the local filesystem to be logged
            artifact_path (str, optional): Remote path where the artifact should be stored
                If None, a default location should be used
        """
        ...


class MLFlowLogger:
    """
    Mlflow logger class that supports rank_zero logging of tags, metrics and distributed
    logging of artifacts.

    Args:
        logger (MLFlowLogger): PTL MLFlow logger object
    """

    def __init__(
        self,
        logger: "MLFlowLogger | None" = None,
        client: MlflowClient | None = None,
        synchronous: bool | None = None,
        run_id: str | None = None,
    ) -> None:
        super().__init__()
        if logger:
            self._client: MlflowClient = assert_not_none(getattr(logger, "_mlflow_client", None))
            self._synchronous = getattr(logger, "_log_batch_kwargs", {}).get("synchronous")
            self._run_id: str = assert_not_none(getattr(logger, "_run_id", None))
        else:
            assert client
            self._client = client
            self._synchronous = synchronous
            assert run_id
            self._run_id = run_id

    @rank_zero_only
    def log_tag(self, key: str, value: str) -> None:
        self._client.set_tag(run_id=self._run_id, key=key, value=value, synchronous=self._synchronous)

    def tags(self) -> dict[str, Any]:
        run = self._client.get_run(self._run_id)
        return run.data.tags

    @rank_zero_only
    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        ms = [Metric(k, v, timestamp, step) for k, v in metrics.items()] if metrics else []
        ps = [Param(k, v) for k, v in params.items()] if params else []
        ts = [RunTag(k, v) for k, v in tags.items()] if tags else []
        self._client.log_batch(
            run_id=self._run_id,
            metrics=ms,
            params=ps,
            tags=ts,
            synchronous=self._synchronous,
        )

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        # TODO: log directly to s3 uri
        # TODO: support async logging
        self._client.log_artifact(
            run_id=self._run_id,
            local_path=local_path,
            artifact_path=artifact_path,
        )


class TensorBoardLogger(LightningLogger):
    """TensorBoard logger with rank_zero logging."""

    def __init__(self, logger: "TensorBoardLogger") -> None:
        self._logger = logger

    @rank_zero_only
    def log_tag(self, key: str, value: str) -> None:
        self._logger.experiment.add_text(key, value)

    def tags(self) -> dict[str, Any]:
        return {}

    @rank_zero_only
    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        if metrics:
            for k, v in metrics.items():
                self._logger.experiment.add_scalar(k, v, step)
        if tags:
            for k, v in tags.items():
                self._logger.experiment.add_text(k, v, step)

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        from pathlib import Path
        import shutil

        dest = Path(self._logger.log_dir) / (artifact_path or Path(local_path).name)
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(local_path, dest)


class WandbLogger(LightningLogger):
    """WandB logger with rank_zero logging."""

    def __init__(self, logger: "WandbLogger") -> None:
        self._logger = logger

    @rank_zero_only
    def log_tag(self, key: str, value: str) -> None:
        self._logger.experiment.config.update({key: value})

    def tags(self) -> dict[str, Any]:
        return dict(self._logger.experiment.config)

    @rank_zero_only
    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        log_dict = {}
        if metrics:
            log_dict.update(metrics)
        if tags:
            log_dict.update(tags)
        if log_dict:
            self._logger.experiment.log(log_dict, step=step)

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        self._logger.experiment.save(local_path)

        if self._logger.experiment.settings.mode == "offline":
            import shutil
            from pathlib import Path

            src = Path(local_path).absolute()
            dest = Path(self._logger.experiment.settings.files_dir) / src.name

            if dest.is_symlink():
                dest.unlink()
            shutil.copy2(src, dest)


class CompositeLogger(LightningLogger):
    """
    A wrapper on top of the collection of :class:`Logger` instances,
    providing methods to log metrics, artifacts, and tags across all registered loggers
    simultaneously.

    Attributes:
        loggers (list[LightningLogger]): List of loggers

    Args:
        trainer (L.Trainer): PyTorch Lightning trainer instance used to initialize loggers
    """

    loggers: list[LightningLogger]

    def __init__(self, trainer: "L.Trainer | None", loggers: list[LightningLogger] | None = None) -> None:
        if trainer:
            self.loggers = []
            for logger in trainer.loggers:
                if _is_logger_type(logger, "MLFlowLogger"):
                    self.loggers.append(MLFlowLogger(logger=logger))  # type: ignore[arg-type]
                elif _is_logger_type(logger, "TensorBoardLogger"):
                    self.loggers.append(TensorBoardLogger(logger=logger))  # type: ignore[arg-type]
                elif _is_logger_type(logger, "WandbLogger"):
                    self.loggers.append(WandbLogger(logger=logger))  # type: ignore[arg-type]
        else:
            assert loggers
            self.loggers = loggers

    def __str__(self) -> str:
        return str([type(obj).__name__ for obj in self.loggers])

    @override
    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        for logger in self.loggers:
            logger.log_artifact(local_path=local_path, artifact_path=artifact_path)

    @override
    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        for logger in self.loggers:
            logger.log_batch(metrics=metrics, tags=tags, timestamp=timestamp, step=step)

    @override
    def tags(self) -> dict[str, Any]:
        tags = {}
        for logger in self.loggers:
            tags.update(logger.tags())
        return tags

    @override
    def log_tag(self, key: str, value: str) -> None:
        for logger in self.loggers:
            logger.log_tag(key=key, value=value)
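
The loggers module above defines a structural LightningLogger protocol plus MLflow, TensorBoard, and Weights & Biases wrappers, with CompositeLogger fanning each call out to every registered backend. A minimal usage sketch follows (not part of the package; the tracking URI, experiment id, and artifact file name are placeholder assumptions):

# Hypothetical usage; the MLflow server, experiment id and artifact file are illustrative.
import time

from mlflow.tracking import MlflowClient

from fkat.pytorch.loggers import CompositeLogger, MLFlowLogger

client = MlflowClient(tracking_uri="http://localhost:5000")  # assumed tracking server
run = client.create_run(experiment_id="0")  # assumed experiment

composite = CompositeLogger(
    trainer=None,
    loggers=[MLFlowLogger(client=client, run_id=run.info.run_id, synchronous=True)],
)
composite.log_tag("preflight", "passed")
composite.log_batch(metrics={"loss": 0.42}, step=100, timestamp=int(time.time() * 1000))
composite.log_artifact("config.yaml")  # placeholder local file

Inside a training job the same wrapper would typically be built as CompositeLogger(trainer=trainer), which wraps whichever of the three supported logger types the trainer was configured with.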

fkat/pytorch/schedule/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .base import (
    Schedule,
    Never,
    Always,
    Fixed,
    Every,
    Elapsed,
    GlobalRank,
    LocalRank,
    InvertedSchedule,
    CombinedSchedule,
)

__all__ = [
    "Schedule",
    "Never",
    "Always",
    "Fixed",
    "Every",
    "Elapsed",
    "GlobalRank",
    "LocalRank",
    "InvertedSchedule",
    "CombinedSchedule",
]
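
Since the package re-exports everything from base, downstream code can import schedule classes from fkat.pytorch.schedule directly rather than from fkat.pytorch.schedule.base, e.g. (values illustrative):

from fkat.pytorch.schedule import Every, GlobalRank

log_every_50 = Every(n_steps=50, stage="train")
rank_zero_schedule = GlobalRank(0)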

fkat/pytorch/schedule/base.py
ADDED
@@ -0,0 +1,308 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime as dt
import operator
from collections.abc import Callable, Sequence
from functools import reduce
from typing import Protocol
from typing_extensions import override

import lightning as L


class Schedule(Protocol):
    """
    Protocol defining a generic PyTorch schedule
    """

    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        """Checks the schedule for the given moment.

        Args:
            stage (str): current trainer stage. eg. train/test/validate/predict
            batch_idx (Optional[int]): current batch_idx
            step (Optional[int]): current step
            trainer (Optional[Trainer]): lightning trainer of callback
        Returns:
            bool: True if schedule passed the check, False otherwise
        """
        ...

    def __and__(self, first: "Schedule", second: "Schedule") -> "CombinedSchedule":
        return self._combine(operator.and_, first, second)

    def __or__(self, first: "Schedule", second: "Schedule") -> "CombinedSchedule":
        return self._combine(operator.or_, first, second)

    def __invert__(self, other: "Schedule") -> "InvertedSchedule":
        return InvertedSchedule(other)

    def _combine(self, fn: Callable[[bool, bool], bool], first: "Schedule", second: "Schedule") -> "CombinedSchedule":
        return CombinedSchedule(fn, (first, second))


class InvertedSchedule(Schedule):
    def __init__(self, other: Schedule) -> None:
        self.other = other

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        return not self.other.check(stage=stage, batch_idx=batch_idx, step=step, trainer=trainer)


class CombinedSchedule(Schedule):
    def __init__(self, fn: Callable[[bool, bool], bool], schedules: Sequence[Schedule]) -> None:
        self.fn = fn
        self.schedules = schedules

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        return reduce(
            self.fn, (s.check(stage=stage, batch_idx=batch_idx, step=step, trainer=trainer) for s in self.schedules)
        )


class GlobalRank(Schedule):
    """
    A schedule that only executes on specific global ranks in a distributed training setup.

    This is useful for operations that should only be performed on certain ranks,
    such as logging, checkpointing, or other operations that would be redundant
    or conflicting if performed on all ranks.

    Attributes:
        ranks (tuple[int, ...]): The global ranks on which this schedule should execute.
    """

    def __init__(self, *ranks: int) -> None:
        """
        Initialize a GlobalRank schedule.

        Args:
            *ranks: Variable number of integer ranks. The schedule will only execute
                on these global ranks.
        """
        self.ranks = ranks

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        """
        Check if the current global rank is in the specified ranks.

        Args:
            stage: Current trainer stage (ignored by this schedule).
            batch_idx: Current batch index (ignored by this schedule).
            step: Current step (ignored by this schedule).
            trainer: Lightning trainer instance, used to get the global rank.

        Returns:
            bool: True if the trainer's global rank is in the specified ranks, False otherwise.
                Always returns False if trainer is None.
        """
        if trainer is None:
            return False
        return trainer.global_rank in self.ranks


class LocalRank(Schedule):
    """
    A schedule that only executes on specific local ranks in a distributed training setup.

    This is useful for node-specific operations that should only be performed on certain
    ranks within each node, such as local logging or monitoring.

    Attributes:
        ranks (tuple[int, ...]): The local ranks on which this schedule should execute.
    """

    def __init__(self, *ranks: int) -> None:
        """
        Initialize a LocalRank schedule.

        Args:
            *ranks: Variable number of integer ranks. The schedule will only execute
                on these local ranks.
        """
        self.ranks = ranks

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        """
        Check if the current local rank is in the specified ranks.

        Args:
            stage: Current trainer stage (ignored by this schedule).
            batch_idx: Current batch index (ignored by this schedule).
            step: Current step (ignored by this schedule).
            trainer: Lightning trainer instance, used to get the local rank.

        Returns:
            bool: True if the trainer's local rank is in the specified ranks, False otherwise.
                Always returns False if trainer is None.
        """
        if trainer is None:
            return False
        return trainer.local_rank in self.ranks


class Never(Schedule):
    """
    A schedule for an event that never happens.
    """

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        return False


class Always(Schedule):
    """
    A schedule for an event that always happens.
    """

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        return True


class Fixed(Schedule):
    """
    A schedule for an event that happens after a warmup period
    and lasts for a fixed number of steps.

    Attributes:
        warmup_steps (int): Number of initial steps to skip before logging starts.
        active_steps (int): Number of steps to log after warmup period.
    """

    def __init__(self, warmup_steps: int, active_steps: int) -> None:
        self._warmup_steps: int = warmup_steps
        self._active_steps: int = active_steps

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        assert step is not None, "step must be provided"
        if step < self._warmup_steps:
            return False

        if step - self._warmup_steps >= self._active_steps:
            return False
        return True


class Every(Schedule):
    """
    A schedule for an event that happens every specified number of batches and/or steps.

    Attributes:
        n_batches (Optional[int]): A positive number of batches between logging events.
            Defaults to 0 - use only n_steps
        n_steps (Optional[int]): A positive number of (train) steps between logging events.
            Defaults to 0 - use only n_batches
        stage (Optional[str]): The stage this schedule applies to ('train', 'validation', 'test', 'predict').
            If None, applies to all stages.
    """

    def __init__(self, *, n_batches: int = 0, n_steps: int = 0, stage: str | None = None) -> None:
        assert n_batches or n_steps, "either n_batches or n_steps has to be a positive number"
        self._n_batches = n_batches
        self._n_steps = n_steps
        self._stage = stage

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        # If stage is specified and doesn't match, return False
        if self._stage is not None and stage != self._stage:
            return False

        return (batch_idx is not None and self._n_batches > 0 and batch_idx % self._n_batches == 0) or (
            step is not None and self._n_steps > 0 and step % self._n_steps == 0
        )


class Elapsed(Schedule):
    """
    A schedule for an event that happens after the provided time interval has elapsed.
    """

    def __init__(self, interval: dt.timedelta) -> None:
        self.interval = interval
        self.last_triggered: dt.datetime | None = None

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        now = dt.datetime.now(dt.timezone.utc)
        if self.last_triggered is None or now - self.last_triggered >= self.interval:
            self.last_triggered = now
            return True
        return False
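
A minimal sketch of how a callback might consult these schedules; the callback class and the chosen values are hypothetical, not from the package, and CombinedSchedule is used directly here to AND two conditions together:

# Hypothetical callback; schedule values are illustrative, not from the package.
import operator
from typing import Any

import lightning as L
from lightning.pytorch.callbacks import Callback

from fkat.pytorch.schedule import CombinedSchedule, Every, GlobalRank


class SnapshotCallback(Callback):
    """Runs some work every 100 train steps, but only on global rank 0."""

    def __init__(self) -> None:
        self._schedule = CombinedSchedule(
            operator.and_,
            (Every(n_steps=100, stage="train"), GlobalRank(0)),
        )

    def on_train_batch_end(
        self, trainer: L.Trainer, pl_module: L.LightningModule, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        if self._schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            ...  # perform the scheduled work here

An Elapsed(dt.timedelta(minutes=10)) entry could be added to the same tuple to additionally rate-limit the work by wall-clock time.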