dream-trainer 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. dream_trainer-0.1.0/.github/workflows/docs.yaml +28 -0
  2. dream_trainer-0.1.0/.github/workflows/pypi.yaml +34 -0
  3. dream_trainer-0.1.0/.gitignore +23 -0
  4. dream_trainer-0.1.0/.python-version +1 -0
  5. dream_trainer-0.1.0/LICENSE +28 -0
  6. dream_trainer-0.1.0/PKG-INFO +21 -0
  7. dream_trainer-0.1.0/README.md +0 -0
  8. dream_trainer-0.1.0/docs/callbacks.md +457 -0
  9. dream_trainer-0.1.0/docs/configuration.md +358 -0
  10. dream_trainer-0.1.0/docs/getting-started.md +321 -0
  11. dream_trainer-0.1.0/docs/index.md +105 -0
  12. dream_trainer-0.1.0/docs/parallelism.md +422 -0
  13. dream_trainer-0.1.0/docs/trainer-guide.md +388 -0
  14. dream_trainer-0.1.0/mkdocs.yml +64 -0
  15. dream_trainer-0.1.0/pyproject.toml +57 -0
  16. dream_trainer-0.1.0/src/dream_trainer/__init__.py +6 -0
  17. dream_trainer-0.1.0/src/dream_trainer/callbacks/__init__.py +42 -0
  18. dream_trainer-0.1.0/src/dream_trainer/callbacks/callback.py +364 -0
  19. dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/__init__.py +3 -0
  20. dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/async.py +102 -0
  21. dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/base.py +195 -0
  22. dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/types.py +33 -0
  23. dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/utils.py +53 -0
  24. dream_trainer-0.1.0/src/dream_trainer/callbacks/fp8.py +127 -0
  25. dream_trainer-0.1.0/src/dream_trainer/callbacks/ft.py +53 -0
  26. dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/__init__.py +19 -0
  27. dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/base.py +81 -0
  28. dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/dist.py +6 -0
  29. dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/media.py +127 -0
  30. dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/metric.py +38 -0
  31. dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/wandb_watch.py +33 -0
  32. dream_trainer-0.1.0/src/dream_trainer/callbacks/progress_bar.py +179 -0
  33. dream_trainer-0.1.0/src/dream_trainer/callbacks/trainer_summary.py +120 -0
  34. dream_trainer-0.1.0/src/dream_trainer/configs/__init__.py +12 -0
  35. dream_trainer-0.1.0/src/dream_trainer/configs/checkpoint.py +46 -0
  36. dream_trainer-0.1.0/src/dream_trainer/configs/logger.py +9 -0
  37. dream_trainer-0.1.0/src/dream_trainer/configs/trainer.py +282 -0
  38. dream_trainer-0.1.0/src/dream_trainer/py.typed +0 -0
  39. dream_trainer-0.1.0/src/dream_trainer/trainer/__init__.py +12 -0
  40. dream_trainer-0.1.0/src/dream_trainer/trainer/abstract.py +116 -0
  41. dream_trainer-0.1.0/src/dream_trainer/trainer/base.py +462 -0
  42. dream_trainer-0.1.0/src/dream_trainer/trainer/dream.py +50 -0
  43. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/__init__.py +41 -0
  44. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/eval_metric.py +47 -0
  45. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/loggers/__init__.py +14 -0
  46. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/loggers/types.py +85 -0
  47. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/loggers/wandb.py +200 -0
  48. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/quantize.py +117 -0
  49. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/__init__.py +15 -0
  50. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/dataloader.py +59 -0
  51. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/models.py +274 -0
  52. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/optimizers.py +85 -0
  53. dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/setup.py +34 -0
  54. dream_trainer-0.1.0/src/dream_trainer/trainer/world/__init__.py +11 -0
  55. dream_trainer-0.1.0/src/dream_trainer/trainer/world/distributed_world.py +543 -0
  56. dream_trainer-0.1.0/src/dream_trainer/trainer/world/fault_tolerant_world.py +100 -0
  57. dream_trainer-0.1.0/src/dream_trainer/utils/__init__.py +6 -0
  58. dream_trainer-0.1.0/src/dream_trainer/utils/_logger.py +75 -0
  59. dream_trainer-0.1.0/src/dream_trainer/utils/common.py +151 -0
  60. dream_trainer-0.1.0/src/dream_trainer/utils/dataloader.py +220 -0
  61. dream_trainer-0.1.0/src/dream_trainer/utils/entrypoint.py +80 -0
  62. dream_trainer-0.1.0/src/dream_trainer/utils/logging.py +36 -0
  63. dream_trainer-0.1.0/src/dream_trainer/utils/materialize.py +61 -0
  64. dream_trainer-0.1.0/src/dream_trainer/utils/names.py +110 -0
  65. dream_trainer-0.1.0/src/dream_trainer/utils/serialize.py +57 -0
@@ -0,0 +1,28 @@
1
+ name: docs
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ permissions:
7
+ contents: write
8
+ jobs:
9
+ deploy:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - uses: actions/checkout@v4
13
+ - name: Configure Git Credentials
14
+ run: |
15
+ git config user.name github-actions[bot]
16
+ git config user.email 41898282+github-actions[bot]@users.noreply.github.com
17
+ - uses: actions/setup-python@v5
18
+ with:
19
+ python-version: 3.x
20
+ - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
21
+ - uses: actions/cache@v4
22
+ with:
23
+ key: mkdocs-material-${{ env.cache_id }}
24
+ path: .cache
25
+ restore-keys: |
26
+ mkdocs-material-
27
+ - run: pip install mkdocs-material "jinja2>=3.1.3" 'mkdocstrings[python]'
28
+ - run: mkdocs gh-deploy --force
@@ -0,0 +1,34 @@
1
+ # pypi.yaml
2
+
3
+ name: "Publish"
4
+
5
+ on:
6
+ release:
7
+ types: ["published"]
8
+
9
+ jobs:
10
+ build:
11
+ name: continuous-integration
12
+ runs-on: ubuntu-latest
13
+ strategy:
14
+ matrix:
15
+ python-version:
16
+ - "3.10"
17
+ - "3.11"
18
+ - "3.12"
19
+
20
+ steps:
21
+ - uses: actions/checkout@v4
22
+
23
+ - name: Install uv and set the Python version
24
+ uses: astral-sh/setup-uv@v5
25
+ with:
26
+ python-version: ${{ matrix.python-version }}
27
+ enable-cache: true
28
+ cache-dependency-glob: uv.lock
29
+
30
+ - name: Build
31
+ run: uv build
32
+
33
+ - name: Publish
34
+ run: uv publish -t ${{ secrets.PYPI_TOKEN }}
@@ -0,0 +1,23 @@
1
+ # Python-generated files
2
+ */__pycache__/
3
+ __pycache__/
4
+ *.py[oc]
5
+ .ruff_cache
6
+
7
+ # models
8
+ **.bin
9
+ **.pkl
10
+ **.pt
11
+ **.pth
12
+ **.safetensors
13
+ **.pkl
14
+
15
+ build/
16
+ dist/
17
+ wheels/
18
+ *.egg-info
19
+
20
+ # Virtual environments
21
+ .venv
22
+ **/.venv/**
23
+
@@ -0,0 +1 @@
1
+ 3.12
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ (c) Dream3D, Inc. and affiliates.
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list
9
+ of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice, this
12
+ list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors may
16
+ be used to endorse or promote products derived from this software without specific
17
+ prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
20
+ EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21
+ OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
22
+ SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
23
+ INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
24
+ TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
25
+ BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
27
+ ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
28
+ DAMAGE.
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: dream-trainer
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Author-email: Tony Francis <tony@dream3d.com>, Vikaas Varma <vik@dream3d.com>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.10
8
+ Requires-Dist: dist-util==0.1.0
9
+ Requires-Dist: loguru>=0.7.3
10
+ Requires-Dist: torch>=2.7.0
11
+ Requires-Dist: tqdm>=4.67.1
12
+ Provides-Extra: metrics
13
+ Requires-Dist: torchmetrics>=1.7.1; extra == 'metrics'
14
+ Provides-Extra: rich
15
+ Requires-Dist: rich>=14.0.0; extra == 'rich'
16
+ Provides-Extra: torchao
17
+ Requires-Dist: torchao>=0.11.0; extra == 'torchao'
18
+ Provides-Extra: torchft
19
+ Requires-Dist: torchft; extra == 'torchft'
20
+ Provides-Extra: wandb
21
+ Requires-Dist: wandb[media]>=0.19.11; extra == 'wandb'
File without changes
@@ -0,0 +1,457 @@
1
+ # Callbacks Guide
2
+
3
+ This guide explains how to use and create callbacks in Dream Trainer.
4
+
5
+ ## Table of Contents
6
+ - [Basic Usage](#basic-usage)
7
+ - [Built-in Callbacks](#built-in-callbacks)
8
+ - [Creating Callbacks](#creating-callbacks)
9
+ - [Callback Collection](#callback-collection)
10
+ - [Best Practices](#best-practices)
11
+
12
+ ## Basic Usage
13
+
14
+ Callbacks are a way to extend the trainer's functionality without modifying its code. They are called at specific points during training.
15
+
16
+ ### Adding Callbacks
17
+
18
+ Add callbacks to your trainer configuration:
19
+
20
+ ```python
21
+ from dream_trainer import DreamTrainerConfig
22
+ from dream_trainer.callbacks import (
23
+ LoggerCallback,
24
+ ProgressBar,
25
+ CallbackCollection
26
+ )
27
+
28
+ config = DreamTrainerConfig(
29
+ # ... other settings ...
30
+ callbacks=CallbackCollection([
31
+ LoggerCallback(), # Logs metrics to console/WandB
32
+ ProgressBar(), # Shows training progress
33
+ ])
34
+ )
35
+ ```
36
+
37
+ ### Callback Order
38
+
39
+ Callbacks are executed in the order they are added. You can control the order:
40
+
41
+ ```python
42
+ callbacks = CallbackCollection([
43
+ LoggerCallback(), # First: log metrics
44
+ ProgressBar(), # Second: show progress
45
+ CheckpointCallback() # Third: save checkpoints
46
+ ])
47
+ ```
48
+
49
+ ## Built-in Callbacks
50
+
51
+ ### LoggerCallback
52
+
53
+ Logs metrics to console and/or WandB:
54
+
55
+ ```python
56
+ from dream_trainer.callbacks import LoggerCallback
57
+
58
+ logger = LoggerCallback(
59
+ log_every_n_steps=100, # Log every 100 steps
60
+ log_every_n_epochs=1, # Log every epoch
61
+ log_metrics=True, # Log metrics
62
+ log_gradients=False, # Don't log gradients
63
+ log_parameters=False # Don't log parameters
64
+ )
65
+ ```
66
+
67
+ ### ProgressBar
68
+
69
+ Shows training progress:
70
+
71
+ ```python
72
+ from dream_trainer.callbacks import ProgressBar
73
+
74
+ progress = ProgressBar(
75
+ refresh_rate=10, # Update every 10 steps
76
+ show_epoch=True, # Show epoch number
77
+ show_step=True, # Show step number
78
+ show_metrics=True # Show metrics
79
+ )
80
+ ```
81
+
82
+ ### CheckpointCallback
83
+
84
+ Saves model checkpoints:
85
+
86
+ ```python
87
+ from dream_trainer.callbacks import CheckpointCallback
88
+
89
+ checkpoint = CheckpointCallback(
90
+ monitor="val_loss", # Metric to monitor
91
+ mode="min", # Minimize metric
92
+ save_top_k=3, # Keep best 3 checkpoints
93
+ save_last=True, # Always save latest
94
+ every_n_epochs=1 # Save every epoch
95
+ )
96
+ ```
97
+
98
+ ### EarlyStoppingCallback
99
+
100
+ Stops training when the monitored metric stops improving:
101
+
102
+ ```python
103
+ from dream_trainer.callbacks import EarlyStoppingCallback
104
+
105
+ early_stopping = EarlyStoppingCallback(
106
+ monitor="val_loss", # Metric to monitor
107
+ mode="min", # Minimize metric
108
+ patience=5, # Wait 5 epochs
109
+ min_delta=0.001 # Minimum change
110
+ )
111
+ ```
112
+
113
+ ### LearningRateMonitor
114
+
115
+ Logs learning rate changes:
116
+
117
+ ```python
118
+ from dream_trainer.callbacks import LearningRateMonitor
119
+
120
+ lr_monitor = LearningRateMonitor(
121
+ logging_interval="step", # Log every step
122
+ log_momentum=True # Log momentum too
123
+ )
124
+ ```
125
+
126
+ ## Creating Callbacks
127
+
128
+ ### Basic Callback
129
+
130
+ Create a custom callback by extending `Callback`:
131
+
132
+ ```python
133
+ from dream_trainer.callbacks import Callback
134
+
135
+ class MyCallback(Callback):
136
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
137
+ """Called after each training batch"""
138
+ # Access trainer state
139
+ current_epoch = trainer.current_epoch
140
+ current_step = trainer.current_step
141
+
142
+ # Access outputs
143
+ loss = outputs["loss"]
144
+
145
+ # Do something
146
+ if loss > 10.0:
147
+ print(f"High loss detected: {loss}")
148
+ ```
149
+
150
+ ### Training Hooks
151
+
152
+ Available training hooks:
153
+
154
+ ```python
155
+ class MyCallback(Callback):
156
+ def on_train_start(self, trainer):
157
+ """Called when training starts"""
158
+ pass
159
+
160
+ def on_train_epoch_start(self, trainer):
161
+ """Called at the start of each training epoch"""
162
+ pass
163
+
164
+ def on_train_batch_start(self, trainer, batch, batch_idx):
165
+ """Called before each training batch"""
166
+ pass
167
+
168
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
169
+ """Called after each training batch"""
170
+ pass
171
+
172
+ def on_train_epoch_end(self, trainer):
173
+ """Called at the end of each training epoch"""
174
+ pass
175
+
176
+ def on_train_end(self, trainer):
177
+ """Called when training ends"""
178
+ pass
179
+ ```
180
+
181
+ ### Validation Hooks
182
+
183
+ Available validation hooks:
184
+
185
+ ```python
186
+ class MyCallback(Callback):
187
+ def on_validation_start(self, trainer):
188
+ """Called when validation starts"""
189
+ pass
190
+
191
+ def on_validation_epoch_start(self, trainer):
192
+ """Called at the start of each validation epoch"""
193
+ pass
194
+
195
+ def on_validation_batch_start(self, trainer, batch, batch_idx):
196
+ """Called before each validation batch"""
197
+ pass
198
+
199
+ def on_validation_batch_end(self, trainer, outputs, batch, batch_idx):
200
+ """Called after each validation batch"""
201
+ pass
202
+
203
+ def on_validation_epoch_end(self, trainer):
204
+ """Called at the end of each validation epoch"""
205
+ pass
206
+
207
+ def on_validation_end(self, trainer):
208
+ """Called when validation ends"""
209
+ pass
210
+ ```
211
+
212
+ ### State Management
213
+
214
+ Callbacks can maintain their own state:
215
+
216
+ ```python
217
+ class StatefulCallback(Callback):
218
+ def __init__(self):
219
+ super().__init__()
220
+ self.best_metric = float('inf')
221
+ self.patience_counter = 0
222
+
223
+ def on_validation_epoch_end(self, trainer):
224
+ # Get current metric
225
+ current_metric = trainer.get_metric("val_loss")
226
+
227
+ # Update state
228
+ if current_metric < self.best_metric:
229
+ self.best_metric = current_metric
230
+ self.patience_counter = 0
231
+ else:
232
+ self.patience_counter += 1
233
+
234
+ # Check patience
235
+ if self.patience_counter >= 5:
236
+ trainer.should_stop = True
237
+ ```
238
+
239
+ ### Accessing Trainer
240
+
241
+ Callbacks have access to the trainer instance:
242
+
243
+ ```python
244
+ class TrainerAwareCallback(Callback):
245
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
246
+ # Access trainer attributes
247
+ model = trainer.model
248
+ optimizer = trainer.optimizer
249
+ current_epoch = trainer.current_epoch
250
+
251
+ # Access trainer methods
252
+ trainer.log("custom_metric", 42)
253
+ trainer.save_checkpoint("path/to/checkpoint.pt")
254
+ ```
255
+
256
+ ## Callback Collection
257
+
258
+ ### Adding Callbacks
259
+
260
+ Add callbacks to a collection:
261
+
262
+ ```python
263
+ from dream_trainer.callbacks import CallbackCollection
264
+
265
+ callbacks = CallbackCollection([
266
+ LoggerCallback(),
267
+ ProgressBar(),
268
+ MyCustomCallback()
269
+ ])
270
+ ```
271
+
272
+ ### Removing Callbacks
273
+
274
+ Remove callbacks from a collection:
275
+
276
+ ```python
277
+ # Remove by type
278
+ callbacks.remove(LoggerCallback)
279
+
280
+ # Remove by instance
281
+ callbacks.remove(my_callback)
282
+ ```
283
+
284
+ ### Reordering Callbacks
285
+
286
+ Change callback order:
287
+
288
+ ```python
289
+ # Move to front
290
+ callbacks.move_to_front(my_callback)
291
+
292
+ # Move to back
293
+ callbacks.move_to_back(my_callback)
294
+
295
+ # Move to specific position
296
+ callbacks.move_to_position(my_callback, 2)
297
+ ```
298
+
299
+ ## Best Practices
300
+
301
+ ### 1. Keep Callbacks Focused
302
+
303
+ Each callback should do one thing well:
304
+
305
+ ```python
306
+ # Good: Single responsibility
307
+ class LossMonitor(Callback):
308
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
309
+ if outputs["loss"] > 10.0:
310
+ print("High loss detected")
311
+
312
+ # Bad: Multiple responsibilities
313
+ class BadCallback(Callback):
314
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
315
+ # Monitoring
316
+ if outputs["loss"] > 10.0:
317
+ print("High loss detected")
318
+ # Logging
319
+ trainer.log("custom_metric", 42)
320
+ # Checkpointing
321
+ trainer.save_checkpoint("checkpoint.pt")
322
+ ```
323
+
324
+ ### 2. Use Type Hints
325
+
326
+ Add type hints for better IDE support:
327
+
328
+ ```python
329
+ from typing import Dict, Any
330
+ import torch
331
+
332
+ class TypedCallback(Callback):
333
+ def on_train_batch_end(
334
+ self,
335
+ trainer: "DreamTrainer",
336
+ outputs: Dict[str, torch.Tensor],
337
+ batch: torch.Tensor,
338
+ batch_idx: int
339
+ ) -> None:
340
+ pass
341
+ ```
342
+
343
+ ### 3. Document Callbacks
344
+
345
+ Add docstrings to explain functionality:
346
+
347
+ ```python
348
+ class DocumentedCallback(Callback):
349
+ """Monitors training metrics and logs warnings.
350
+
351
+ This callback watches for:
352
+ - High loss values
353
+ - NaN gradients
354
+ - Learning rate spikes
355
+
356
+ Args:
357
+ loss_threshold: Threshold for high loss warning
358
+ lr_threshold: Threshold for learning rate warning
359
+ """
360
+
361
+ def __init__(self, loss_threshold: float = 10.0, lr_threshold: float = 1e-2):
362
+ super().__init__()
363
+ self.loss_threshold = loss_threshold
364
+ self.lr_threshold = lr_threshold
365
+ ```
366
+
367
+ ### 4. Handle Errors
368
+
369
+ Add proper error handling:
370
+
371
+ ```python
372
+ class ErrorHandlingCallback(Callback):
373
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
374
+ try:
375
+ # Risky operation
376
+ self.process_outputs(outputs)
377
+ except Exception as e:
378
+ # Log error but don't crash
379
+ trainer.log("callback_error", str(e))
380
+ ```
381
+
382
+ ### 5. Test Callbacks
383
+
384
+ Write unit tests for your callbacks:
385
+
386
+ ```python
387
+ def test_my_callback():
388
+ # Create mock trainer
389
+ trainer = MockTrainer()
390
+
391
+ # Create callback
392
+ callback = MyCallback()
393
+
394
+ # Test hook
395
+ callback.on_train_batch_end(
396
+ trainer,
397
+ outputs={"loss": torch.tensor(5.0)},
398
+ batch=torch.randn(32, 10),
399
+ batch_idx=0
400
+ )
401
+
402
+ # Assert expected behavior
403
+ assert trainer.logged_metrics["custom_metric"] == 42
404
+ ```
405
+
406
+ ### 6. Use Callback Priority
407
+
408
+ Set callback priority for execution order:
409
+
410
+ ```python
411
+ class HighPriorityCallback(Callback):
412
+ priority = 100 # Higher number = earlier execution
413
+
414
+ class LowPriorityCallback(Callback):
415
+ priority = 0 # Lower number = later execution
416
+ ```
417
+
418
+ ### 7. Avoid Side Effects
419
+
420
+ Minimize side effects in callbacks:
421
+
422
+ ```python
423
+ class CleanCallback(Callback):
424
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
425
+ # Good: Only logging
426
+ trainer.log("metric", outputs["loss"])
427
+
428
+ # Bad: Modifying trainer state
429
+ trainer.model.requires_grad_(False) # Don't do this
430
+ ```
431
+
432
+ ### 8. Use Callback Groups
433
+
434
+ Group related callbacks:
435
+
436
+ ```python
437
+ class MonitoringGroup(Callback):
438
+ """Group of monitoring callbacks"""
439
+
440
+ def __init__(self):
441
+ super().__init__()
442
+ self.callbacks = [
443
+ LossMonitor(),
444
+ GradientMonitor(),
445
+ LearningRateMonitor()
446
+ ]
447
+
448
+ def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
449
+ for callback in self.callbacks:
450
+ callback.on_train_batch_end(trainer, outputs, batch, batch_idx)
451
+ ```
452
+
453
+ ## Next Steps
454
+
455
+ - Explore [Examples](examples.md) to see callbacks in action
456
+ - Read about [Distributed Training](parallelism.md) for multi-GPU callback considerations
457
+ - Check the [API Reference](api-reference.md) for detailed callback documentation