fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/loggers.py
@@ -0,0 +1,284 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Any, Protocol, TYPE_CHECKING
+ from typing_extensions import override
+
+ import lightning as L
+ from lightning.pytorch.utilities import rank_zero_only
+ from mlflow.entities import Metric, RunTag, Param
+ from mlflow.tracking import MlflowClient  # type: ignore[possibly-unbound-import]
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger
+
+ from fkat.utils import assert_not_none
+
+
+ def _is_logger_type(logger: Any, logger_name: str) -> bool:
+     """Check if logger matches type from lightning or pytorch_lightning."""
+     module = type(logger).__module__
+     return type(logger).__name__ == logger_name and (
+         module.startswith("lightning.pytorch.loggers") or module.startswith("pytorch_lightning.loggers")
+     )
+
+
+ class LightningLogger(Protocol):
+     """
+     Protocol defining the interface for loggers that handle metrics, tags, and artifacts.
+     """
+
+     def tags(self) -> dict[str, Any]:
+         """Get the current tags."""
+         ...
+
+     def log_tag(self, key: str, value: str) -> None:
+         """
+         Log a single key-value tag.
+
+         Args:
+             key (str): The identifier/name of the tag
+             value (str): The value associated with the tag
+         """
+         ...
+
+     def log_batch(
+         self,
+         metrics: dict[str, float] | None = None,
+         params: dict[str, Any] | None = None,
+         tags: dict[str, str] | None = None,
+         timestamp: int | None = None,
+         step: int | None = None,
+     ) -> None:
+         """
+         Log multiple metrics, params, and/or tags in a single batch operation.
+
+         Args:
+             metrics (dict[str, float], optional): Dictionary mapping metric names to their float values
+             params (dict[str, Any], optional): Dictionary mapping parameter names to their values
+             tags (dict[str, str], optional): Dictionary mapping tag names to their string values
+             timestamp (int, optional): Unix timestamp for when the batch was logged
+             step (int, optional): Training step or iteration number
+         """
+         ...
+
+     def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+         """
+         Log a local file as an artifact.
+
+         Args:
+             local_path (str): Path to the file on the local filesystem to be logged
+             artifact_path (str, optional): Remote path where the artifact should be stored.
+                 If None, a default location should be used.
+         """
+         ...
+
+
+ class MLFlowLogger:
+     """
+     MLflow logger that supports rank-zero logging of tags and metrics, and distributed
+     logging of artifacts.
+
+     Args:
+         logger (MLFlowLogger): PyTorch Lightning MLFlow logger object
+     """
+
+     def __init__(
+         self,
+         logger: "MLFlowLogger | None" = None,
+         client: MlflowClient | None = None,
+         synchronous: bool | None = None,
+         run_id: str | None = None,
+     ) -> None:
+         super().__init__()
+         if logger:
+             self._client: MlflowClient = assert_not_none(getattr(logger, "_mlflow_client", None))
+             self._synchronous = getattr(logger, "_log_batch_kwargs", {}).get("synchronous")
+             self._run_id: str = assert_not_none(getattr(logger, "_run_id", None))
+         else:
+             assert client
+             self._client = client
+             self._synchronous = synchronous
+             assert run_id
+             self._run_id = run_id
+
+     @rank_zero_only
+     def log_tag(self, key: str, value: str) -> None:
+         self._client.set_tag(run_id=self._run_id, key=key, value=value, synchronous=self._synchronous)
+
+     def tags(self) -> dict[str, Any]:
+         run = self._client.get_run(self._run_id)
+         return run.data.tags
+
+     @rank_zero_only
+     def log_batch(
+         self,
+         metrics: dict[str, float] | None = None,
+         params: dict[str, Any] | None = None,
+         tags: dict[str, str] | None = None,
+         timestamp: int | None = None,
+         step: int | None = None,
+     ) -> None:
+         ms = [Metric(k, v, timestamp, step) for k, v in metrics.items()] if metrics else []
+         ps = [Param(k, v) for k, v in params.items()] if params else []
+         ts = [RunTag(k, v) for k, v in tags.items()] if tags else []
+         self._client.log_batch(
+             run_id=self._run_id,
+             metrics=ms,
+             params=ps,
+             tags=ts,
+             synchronous=self._synchronous,
+         )
+
+     def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+         # TODO: log directly to s3 uri
+         # TODO: support async logging
+         self._client.log_artifact(
+             run_id=self._run_id,
+             local_path=local_path,
+             artifact_path=artifact_path,
+         )
+
+
+ class TensorBoardLogger(LightningLogger):
+     """TensorBoard logger with rank_zero logging."""
+
+     def __init__(self, logger: "TensorBoardLogger") -> None:
+         self._logger = logger
+
+     @rank_zero_only
+     def log_tag(self, key: str, value: str) -> None:
+         self._logger.experiment.add_text(key, value)
+
+     def tags(self) -> dict[str, Any]:
+         return {}
+
+     @rank_zero_only
+     def log_batch(
+         self,
+         metrics: dict[str, float] | None = None,
+         params: dict[str, Any] | None = None,
+         tags: dict[str, str] | None = None,
+         timestamp: int | None = None,
+         step: int | None = None,
+     ) -> None:
+         if metrics:
+             for k, v in metrics.items():
+                 self._logger.experiment.add_scalar(k, v, step)
+         if tags:
+             for k, v in tags.items():
+                 self._logger.experiment.add_text(k, v, step)
+
+     def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+         from pathlib import Path
+         import shutil
+
+         dest = Path(self._logger.log_dir) / (artifact_path or Path(local_path).name)
+         dest.parent.mkdir(parents=True, exist_ok=True)
+         shutil.copy2(local_path, dest)
+
+
+ class WandbLogger(LightningLogger):
+     """WandB logger with rank_zero logging."""
+
+     def __init__(self, logger: "WandbLogger") -> None:
+         self._logger = logger
+
+     @rank_zero_only
+     def log_tag(self, key: str, value: str) -> None:
+         self._logger.experiment.config.update({key: value})
+
+     def tags(self) -> dict[str, Any]:
+         return dict(self._logger.experiment.config)
+
+     @rank_zero_only
+     def log_batch(
+         self,
+         metrics: dict[str, float] | None = None,
+         params: dict[str, Any] | None = None,
+         tags: dict[str, str] | None = None,
+         timestamp: int | None = None,
+         step: int | None = None,
+     ) -> None:
+         log_dict = {}
+         if metrics:
+             log_dict.update(metrics)
+         if tags:
+             log_dict.update(tags)
+         if log_dict:
+             self._logger.experiment.log(log_dict, step=step)
+
+     def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+         self._logger.experiment.save(local_path)
+
+         if self._logger.experiment.settings.mode == "offline":
+             import shutil
+             from pathlib import Path
+
+             src = Path(local_path).absolute()
+             dest = Path(self._logger.experiment.settings.files_dir) / src.name
+
+             if dest.is_symlink():
+                 dest.unlink()
+             shutil.copy2(src, dest)
+
+
+ class CompositeLogger(LightningLogger):
+     """
+     A wrapper around a collection of :class:`LightningLogger` instances,
+     providing methods to log metrics, artifacts, and tags across all registered loggers
+     simultaneously.
+
+     Attributes:
+         loggers (list[LightningLogger]): List of loggers
+
+     Args:
+         trainer (L.Trainer): PyTorch Lightning trainer instance used to initialize loggers
+     """
+
+     loggers: list[LightningLogger]
+
+     def __init__(self, trainer: "L.Trainer | None", loggers: list[LightningLogger] | None = None) -> None:
+         if trainer:
+             self.loggers = []
+             for logger in trainer.loggers:
+                 if _is_logger_type(logger, "MLFlowLogger"):
+                     self.loggers.append(MLFlowLogger(logger=logger))  # type: ignore[arg-type]
+                 elif _is_logger_type(logger, "TensorBoardLogger"):
+                     self.loggers.append(TensorBoardLogger(logger=logger))  # type: ignore[arg-type]
+                 elif _is_logger_type(logger, "WandbLogger"):
+                     self.loggers.append(WandbLogger(logger=logger))  # type: ignore[arg-type]
+         else:
+             assert loggers
+             self.loggers = loggers
+
+     def __str__(self) -> str:
+         return str([type(obj).__name__ for obj in self.loggers])
+
+     @override
+     def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+         for logger in self.loggers:
+             logger.log_artifact(local_path=local_path, artifact_path=artifact_path)
+
+     @override
+     def log_batch(
+         self,
+         metrics: dict[str, float] | None = None,
+         params: dict[str, Any] | None = None,
+         tags: dict[str, str] | None = None,
+         timestamp: int | None = None,
+         step: int | None = None,
+     ) -> None:
+         for logger in self.loggers:
+             logger.log_batch(metrics=metrics, params=params, tags=tags, timestamp=timestamp, step=step)
+
+     @override
+     def tags(self) -> dict[str, Any]:
+         tags = {}
+         for logger in self.loggers:
+             tags.update(logger.tags())
+         return tags
+
+     @override
+     def log_tag(self, key: str, value: str) -> None:
+         for logger in self.loggers:
+             logger.log_tag(key=key, value=value)
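
A brief illustrative sketch of how the loggers above might be used (not part of the package diff; `trainer` stands in for a configured `lightning.Trainer` whose loggers include one of the supported types):

    from fkat.pytorch.loggers import CompositeLogger

    composite = CompositeLogger(trainer)                    # wraps the trainer's recognized loggers
    composite.log_tag("run_type", "smoke-test")             # rank-zero tag on every wrapped logger
    composite.log_batch(metrics={"loss": 0.42}, step=100)   # batched metrics fan out to all loggers
    composite.log_artifact("config.yaml")                   # each wrapped logger stores the file
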
fkat/pytorch/schedule/__init__.py
@@ -0,0 +1,27 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from .base import (
+     Schedule,
+     Never,
+     Always,
+     Fixed,
+     Every,
+     Elapsed,
+     GlobalRank,
+     LocalRank,
+     InvertedSchedule,
+     CombinedSchedule,
+ )
+
+ __all__ = [
+     "Schedule",
+     "Never",
+     "Always",
+     "Fixed",
+     "Every",
+     "Elapsed",
+     "GlobalRank",
+     "LocalRank",
+     "InvertedSchedule",
+     "CombinedSchedule",
+ ]
fkat/pytorch/schedule/base.py
@@ -0,0 +1,308 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import datetime as dt
+ import operator
+ from collections.abc import Callable, Sequence
+ from functools import reduce
+ from typing import Protocol
+ from typing_extensions import override
+
+ import lightning as L
+
+
+ class Schedule(Protocol):
+     """
+     Protocol defining a generic PyTorch schedule.
+     """
+
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         """Checks the schedule for the given moment.
+
+         Args:
+             stage (Optional[str]): current trainer stage, e.g. train/test/validate/predict
+             batch_idx (Optional[int]): current batch index
+             step (Optional[int]): current step
+             trainer (Optional[Trainer]): Lightning trainer of the callback
+         Returns:
+             bool: True if the schedule passed the check, False otherwise
+         """
+         ...
+
+     def __and__(self, other: "Schedule") -> "CombinedSchedule":
+         return self._combine(operator.and_, self, other)
+
+     def __or__(self, other: "Schedule") -> "CombinedSchedule":
+         return self._combine(operator.or_, self, other)
+
+     def __invert__(self) -> "InvertedSchedule":
+         return InvertedSchedule(self)
+
+     def _combine(self, fn: Callable[[bool, bool], bool], first: "Schedule", second: "Schedule") -> "CombinedSchedule":
+         return CombinedSchedule(fn, (first, second))
+
+
+ class InvertedSchedule(Schedule):
+     def __init__(self, other: Schedule) -> None:
+         self.other = other
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         return not self.other.check(stage=stage, batch_idx=batch_idx, step=step, trainer=trainer)
+
+
+ class CombinedSchedule(Schedule):
+     def __init__(self, fn: Callable[[bool, bool], bool], schedules: Sequence[Schedule]) -> None:
+         self.fn = fn
+         self.schedules = schedules
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         return reduce(
+             self.fn, (s.check(stage=stage, batch_idx=batch_idx, step=step, trainer=trainer) for s in self.schedules)
+         )
+
+
+ class GlobalRank(Schedule):
+     """
+     A schedule that only executes on specific global ranks in a distributed training setup.
+
+     This is useful for operations that should only be performed on certain ranks,
+     such as logging, checkpointing, or other operations that would be redundant
+     or conflicting if performed on all ranks.
+
+     Attributes:
+         ranks (tuple[int, ...]): The global ranks on which this schedule should execute.
+     """
+
+     def __init__(self, *ranks: int) -> None:
+         """
+         Initialize a GlobalRank schedule.
+
+         Args:
+             *ranks: Variable number of integer ranks. The schedule will only execute
+                 on these global ranks.
+         """
+         self.ranks = ranks
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         """
+         Check if the current global rank is in the specified ranks.
+
+         Args:
+             stage: Current trainer stage (ignored by this schedule).
+             batch_idx: Current batch index (ignored by this schedule).
+             step: Current step (ignored by this schedule).
+             trainer: Lightning trainer instance, used to get the global rank.
+
+         Returns:
+             bool: True if the trainer's global rank is in the specified ranks, False otherwise.
+                 Always returns False if trainer is None.
+         """
+         if trainer is None:
+             return False
+         return trainer.global_rank in self.ranks
+
+
+ class LocalRank(Schedule):
+     """
+     A schedule that only executes on specific local ranks in a distributed training setup.
+
+     This is useful for node-specific operations that should only be performed on certain
+     ranks within each node, such as local logging or monitoring.
+
+     Attributes:
+         ranks (tuple[int, ...]): The local ranks on which this schedule should execute.
+     """
+
+     def __init__(self, *ranks: int) -> None:
+         """
+         Initialize a LocalRank schedule.
+
+         Args:
+             *ranks: Variable number of integer ranks. The schedule will only execute
+                 on these local ranks.
+         """
+         self.ranks = ranks
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         """
+         Check if the current local rank is in the specified ranks.
+
+         Args:
+             stage: Current trainer stage (ignored by this schedule).
+             batch_idx: Current batch index (ignored by this schedule).
+             step: Current step (ignored by this schedule).
+             trainer: Lightning trainer instance, used to get the local rank.
+
+         Returns:
+             bool: True if the trainer's local rank is in the specified ranks, False otherwise.
+                 Always returns False if trainer is None.
+         """
+         if trainer is None:
+             return False
+         return trainer.local_rank in self.ranks
+
+
+ class Never(Schedule):
+     """
+     A schedule for an event that never happens.
+     """
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         return False
+
+
+ class Always(Schedule):
+     """
+     A schedule for an event that always happens.
+     """
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         return True
+
+
+ class Fixed(Schedule):
+     """
+     A schedule for an event that happens after a warmup period
+     and lasts for a fixed number of steps.
+
+     Attributes:
+         warmup_steps (int): Number of initial steps to skip before logging starts.
+         active_steps (int): Number of steps to log after warmup period.
+     """
+
+     def __init__(self, warmup_steps: int, active_steps: int) -> None:
+         self._warmup_steps: int = warmup_steps
+         self._active_steps: int = active_steps
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         assert step is not None, "step must be provided"
+         if step < self._warmup_steps:
+             return False
+
+         if step - self._warmup_steps >= self._active_steps:
+             return False
+         return True
+
+
+ class Every(Schedule):
+     """
+     A schedule for an event that happens every specified number of batches and/or steps.
+
+     Attributes:
+         n_batches (Optional[int]): A positive number of batches between logging events.
+             Defaults to 0 - use only n_steps
+         n_steps (Optional[int]): A positive number of (train) steps between logging events.
+             Defaults to 0 - use only n_batches
+         stage (Optional[str]): The stage this schedule applies to ('train', 'validation', 'test', 'predict').
+             If None, applies to all stages.
+     """
+
+     def __init__(self, *, n_batches: int = 0, n_steps: int = 0, stage: str | None = None) -> None:
+         assert n_batches or n_steps, "either n_batches or n_steps has to be a positive number"
+         self._n_batches = n_batches
+         self._n_steps = n_steps
+         self._stage = stage
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         # If stage is specified and doesn't match, return False
+         if self._stage is not None and stage != self._stage:
+             return False
+
+         return (batch_idx is not None and self._n_batches > 0 and batch_idx % self._n_batches == 0) or (
+             step is not None and self._n_steps > 0 and step % self._n_steps == 0
+         )
+
+
+ class Elapsed(Schedule):
+     """
+     A schedule for an event that happens after the provided time interval has elapsed.
+     """
+
+     def __init__(self, interval: dt.timedelta) -> None:
+         self.interval = interval
+         self.last_triggered: dt.datetime | None = None
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         now = dt.datetime.now(dt.timezone.utc)
+         if self.last_triggered is None or now - self.last_triggered >= self.interval:
+             self.last_triggered = now
+             return True
+         return False
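
A brief illustrative sketch of how the schedules above compose (not part of the package diff; `trainer` stands in for the `lightning.Trainer` available inside a callback, and the `&` combination assumes the operator overloads shown above):

    import datetime as dt
    from fkat.pytorch.schedule import Elapsed, Every, GlobalRank

    # Fire every 100 training batches, but only on global rank 0.
    rank0_every_100 = Every(n_batches=100, stage="train") & GlobalRank(0)

    # Fire at most once per 10 minutes, wherever it is checked.
    every_10_min = Elapsed(dt.timedelta(minutes=10))

    if rank0_every_100.check(stage="train", batch_idx=200, trainer=trainer):
        ...  # run the scheduled work (e.g. start a profiler, emit a heartbeat)
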