dream-trainer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dream_trainer-0.1.0/.github/workflows/docs.yaml +28 -0
- dream_trainer-0.1.0/.github/workflows/pypi.yaml +34 -0
- dream_trainer-0.1.0/.gitignore +23 -0
- dream_trainer-0.1.0/.python-version +1 -0
- dream_trainer-0.1.0/LICENSE +28 -0
- dream_trainer-0.1.0/PKG-INFO +21 -0
- dream_trainer-0.1.0/README.md +0 -0
- dream_trainer-0.1.0/docs/callbacks.md +457 -0
- dream_trainer-0.1.0/docs/configuration.md +358 -0
- dream_trainer-0.1.0/docs/getting-started.md +321 -0
- dream_trainer-0.1.0/docs/index.md +105 -0
- dream_trainer-0.1.0/docs/parallelism.md +422 -0
- dream_trainer-0.1.0/docs/trainer-guide.md +388 -0
- dream_trainer-0.1.0/mkdocs.yml +64 -0
- dream_trainer-0.1.0/pyproject.toml +57 -0
- dream_trainer-0.1.0/src/dream_trainer/__init__.py +6 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/__init__.py +42 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/callback.py +364 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/__init__.py +3 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/async.py +102 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/base.py +195 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/types.py +33 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/checkpoint/utils.py +53 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/fp8.py +127 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/ft.py +53 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/__init__.py +19 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/base.py +81 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/dist.py +6 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/media.py +127 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/metric.py +38 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/loggers/wandb_watch.py +33 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/progress_bar.py +179 -0
- dream_trainer-0.1.0/src/dream_trainer/callbacks/trainer_summary.py +120 -0
- dream_trainer-0.1.0/src/dream_trainer/configs/__init__.py +12 -0
- dream_trainer-0.1.0/src/dream_trainer/configs/checkpoint.py +46 -0
- dream_trainer-0.1.0/src/dream_trainer/configs/logger.py +9 -0
- dream_trainer-0.1.0/src/dream_trainer/configs/trainer.py +282 -0
- dream_trainer-0.1.0/src/dream_trainer/py.typed +0 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/__init__.py +12 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/abstract.py +116 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/base.py +462 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/dream.py +50 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/__init__.py +41 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/eval_metric.py +47 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/loggers/__init__.py +14 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/loggers/types.py +85 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/loggers/wandb.py +200 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/quantize.py +117 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/__init__.py +15 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/dataloader.py +59 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/models.py +274 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/optimizers.py +85 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/mixins/setup/setup.py +34 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/world/__init__.py +11 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/world/distributed_world.py +543 -0
- dream_trainer-0.1.0/src/dream_trainer/trainer/world/fault_tolerant_world.py +100 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/__init__.py +6 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/_logger.py +75 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/common.py +151 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/dataloader.py +220 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/entrypoint.py +80 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/logging.py +36 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/materialize.py +61 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/names.py +110 -0
- dream_trainer-0.1.0/src/dream_trainer/utils/serialize.py +57 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
name: docs
|
|
2
|
+
on:
|
|
3
|
+
push:
|
|
4
|
+
branches:
|
|
5
|
+
- main
|
|
6
|
+
permissions:
|
|
7
|
+
contents: write
|
|
8
|
+
jobs:
|
|
9
|
+
deploy:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
steps:
|
|
12
|
+
- uses: actions/checkout@v4
|
|
13
|
+
- name: Configure Git Credentials
|
|
14
|
+
run: |
|
|
15
|
+
git config user.name github-actions[bot]
|
|
16
|
+
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
|
|
17
|
+
- uses: actions/setup-python@v5
|
|
18
|
+
with:
|
|
19
|
+
python-version: 3.x
|
|
20
|
+
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
|
21
|
+
- uses: actions/cache@v4
|
|
22
|
+
with:
|
|
23
|
+
key: mkdocs-material-${{ env.cache_id }}
|
|
24
|
+
path: .cache
|
|
25
|
+
restore-keys: |
|
|
26
|
+
mkdocs-material-
|
|
27
|
+
- run: pip install mkdocs-material "jinja2>=3.1.3" 'mkdocstrings[python]'
|
|
28
|
+
- run: mkdocs gh-deploy --force
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# pypi.yaml
|
|
2
|
+
|
|
3
|
+
name: "Publish"
|
|
4
|
+
|
|
5
|
+
on:
|
|
6
|
+
release:
|
|
7
|
+
types: ["published"]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
build:
|
|
11
|
+
name: continuous-integration
|
|
12
|
+
runs-on: ubuntu-latest
|
|
13
|
+
strategy:
|
|
14
|
+
matrix:
|
|
15
|
+
python-version:
|
|
16
|
+
- "3.10"
|
|
17
|
+
- "3.11"
|
|
18
|
+
- "3.12"
|
|
19
|
+
|
|
20
|
+
steps:
|
|
21
|
+
- uses: actions/checkout@v4
|
|
22
|
+
|
|
23
|
+
- name: Install uv and set the Python version
|
|
24
|
+
uses: astral-sh/setup-uv@v5
|
|
25
|
+
with:
|
|
26
|
+
python-version: ${{ matrix.python-version }}
|
|
27
|
+
enable-cache: true
|
|
28
|
+
cache-dependency-glob: uv.lock
|
|
29
|
+
|
|
30
|
+
- name: Build
|
|
31
|
+
run: uv build
|
|
32
|
+
|
|
33
|
+
- name: Publish
|
|
34
|
+
run: uv publish -t ${{ secrets.PYPI_TOKEN }}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Python-generated files
|
|
2
|
+
*/__pycache__/
|
|
3
|
+
__pycache__/
|
|
4
|
+
*.py[oc]
|
|
5
|
+
.ruff_cache
|
|
6
|
+
|
|
7
|
+
# models
|
|
8
|
+
**.bin
|
|
9
|
+
**.pkl
|
|
10
|
+
**.pt
|
|
11
|
+
**.pth
|
|
12
|
+
**.safetensors
|
|
13
|
+
**.pkl
|
|
14
|
+
|
|
15
|
+
build/
|
|
16
|
+
dist/
|
|
17
|
+
wheels/
|
|
18
|
+
*.egg-info
|
|
19
|
+
|
|
20
|
+
# Virtual environments
|
|
21
|
+
.venv
|
|
22
|
+
**/.venv/**
|
|
23
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.12
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
(c) Dream3D, Inc. and affiliates.
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without modification,
|
|
6
|
+
are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this list
|
|
9
|
+
of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice, this
|
|
12
|
+
list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its contributors may
|
|
16
|
+
be used to endorse or promote products derived from this software without specific
|
|
17
|
+
prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
|
|
20
|
+
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
|
21
|
+
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
|
|
22
|
+
SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
|
|
23
|
+
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
|
24
|
+
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
|
|
25
|
+
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
|
26
|
+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
|
|
27
|
+
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
|
28
|
+
DAMAGE.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dream-trainer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author-email: Tony Francis <tony@dream3d.com>, Vikaas Varma <vik@dream3d.com>
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Python: >=3.10
|
|
8
|
+
Requires-Dist: dist-util==0.1.0
|
|
9
|
+
Requires-Dist: loguru>=0.7.3
|
|
10
|
+
Requires-Dist: torch>=2.7.0
|
|
11
|
+
Requires-Dist: tqdm>=4.67.1
|
|
12
|
+
Provides-Extra: metrics
|
|
13
|
+
Requires-Dist: torchmetrics>=1.7.1; extra == 'metrics'
|
|
14
|
+
Provides-Extra: rich
|
|
15
|
+
Requires-Dist: rich>=14.0.0; extra == 'rich'
|
|
16
|
+
Provides-Extra: torchao
|
|
17
|
+
Requires-Dist: torchao>=0.11.0; extra == 'torchao'
|
|
18
|
+
Provides-Extra: torchft
|
|
19
|
+
Requires-Dist: torchft; extra == 'torchft'
|
|
20
|
+
Provides-Extra: wandb
|
|
21
|
+
Requires-Dist: wandb[media]>=0.19.11; extra == 'wandb'
|
|
File without changes
|
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
# Callbacks Guide
|
|
2
|
+
|
|
3
|
+
This guide explains how to use and create callbacks in Dream Trainer.
|
|
4
|
+
|
|
5
|
+
## Table of Contents
|
|
6
|
+
- [Basic Usage](#basic-usage)
|
|
7
|
+
- [Built-in Callbacks](#built-in-callbacks)
|
|
8
|
+
- [Creating Callbacks](#creating-callbacks)
|
|
9
|
+
- [Callback Collection](#callback-collection)
|
|
10
|
+
- [Best Practices](#best-practices)
|
|
11
|
+
|
|
12
|
+
## Basic Usage
|
|
13
|
+
|
|
14
|
+
Callbacks are a way to extend the trainer's functionality without modifying its code. They are called at specific points during training.
|
|
15
|
+
|
|
16
|
+
### Adding Callbacks
|
|
17
|
+
|
|
18
|
+
Add callbacks to your trainer configuration:
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
from dream_trainer import DreamTrainerConfig
|
|
22
|
+
from dream_trainer.callbacks import (
|
|
23
|
+
LoggerCallback,
|
|
24
|
+
ProgressBar,
|
|
25
|
+
CallbackCollection
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
config = DreamTrainerConfig(
|
|
29
|
+
# ... other settings ...
|
|
30
|
+
callbacks=CallbackCollection([
|
|
31
|
+
LoggerCallback(), # Logs metrics to console/WandB
|
|
32
|
+
ProgressBar(), # Shows training progress
|
|
33
|
+
])
|
|
34
|
+
)
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
### Callback Order
|
|
38
|
+
|
|
39
|
+
Callbacks are executed in the order they are added, so you can control the execution order by how you arrange them in the collection:
|
|
40
|
+
|
|
41
|
+
```python
|
|
42
|
+
callbacks = CallbackCollection([
|
|
43
|
+
LoggerCallback(), # First: log metrics
|
|
44
|
+
ProgressBar(), # Second: show progress
|
|
45
|
+
CheckpointCallback() # Third: save checkpoints
|
|
46
|
+
])
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Built-in Callbacks
|
|
50
|
+
|
|
51
|
+
### LoggerCallback
|
|
52
|
+
|
|
53
|
+
Logs metrics to console and/or WandB:
|
|
54
|
+
|
|
55
|
+
```python
|
|
56
|
+
from dream_trainer.callbacks import LoggerCallback
|
|
57
|
+
|
|
58
|
+
logger = LoggerCallback(
|
|
59
|
+
log_every_n_steps=100, # Log every 100 steps
|
|
60
|
+
log_every_n_epochs=1, # Log every epoch
|
|
61
|
+
log_metrics=True, # Log metrics
|
|
62
|
+
log_gradients=False, # Don't log gradients
|
|
63
|
+
log_parameters=False # Don't log parameters
|
|
64
|
+
)
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
### ProgressBar
|
|
68
|
+
|
|
69
|
+
Shows training progress:
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
from dream_trainer.callbacks import ProgressBar
|
|
73
|
+
|
|
74
|
+
progress = ProgressBar(
|
|
75
|
+
refresh_rate=10, # Update every 10 steps
|
|
76
|
+
show_epoch=True, # Show epoch number
|
|
77
|
+
show_step=True, # Show step number
|
|
78
|
+
show_metrics=True # Show metrics
|
|
79
|
+
)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
### CheckpointCallback
|
|
83
|
+
|
|
84
|
+
Saves model checkpoints:
|
|
85
|
+
|
|
86
|
+
```python
|
|
87
|
+
from dream_trainer.callbacks import CheckpointCallback
|
|
88
|
+
|
|
89
|
+
checkpoint = CheckpointCallback(
|
|
90
|
+
monitor="val_loss", # Metric to monitor
|
|
91
|
+
mode="min", # Minimize metric
|
|
92
|
+
save_top_k=3, # Keep best 3 checkpoints
|
|
93
|
+
save_last=True, # Always save latest
|
|
94
|
+
every_n_epochs=1 # Save every epoch
|
|
95
|
+
)
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
### EarlyStoppingCallback
|
|
99
|
+
|
|
100
|
+
Stops training when the monitored metric stops improving:
|
|
101
|
+
|
|
102
|
+
```python
|
|
103
|
+
from dream_trainer.callbacks import EarlyStoppingCallback
|
|
104
|
+
|
|
105
|
+
early_stopping = EarlyStoppingCallback(
|
|
106
|
+
monitor="val_loss", # Metric to monitor
|
|
107
|
+
mode="min", # Minimize metric
|
|
108
|
+
patience=5, # Wait 5 epochs
|
|
109
|
+
min_delta=0.001 # Minimum change
|
|
110
|
+
)
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### LearningRateMonitor
|
|
114
|
+
|
|
115
|
+
Logs learning rate changes:
|
|
116
|
+
|
|
117
|
+
```python
|
|
118
|
+
from dream_trainer.callbacks import LearningRateMonitor
|
|
119
|
+
|
|
120
|
+
lr_monitor = LearningRateMonitor(
|
|
121
|
+
logging_interval="step", # Log every step
|
|
122
|
+
log_momentum=True # Log momentum too
|
|
123
|
+
)
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
## Creating Callbacks
|
|
127
|
+
|
|
128
|
+
### Basic Callback
|
|
129
|
+
|
|
130
|
+
Create a custom callback by extending `Callback`:
|
|
131
|
+
|
|
132
|
+
```python
|
|
133
|
+
from dream_trainer.callbacks import Callback
|
|
134
|
+
|
|
135
|
+
class MyCallback(Callback):
|
|
136
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
137
|
+
"""Called after each training batch"""
|
|
138
|
+
# Access trainer state
|
|
139
|
+
current_epoch = trainer.current_epoch
|
|
140
|
+
current_step = trainer.current_step
|
|
141
|
+
|
|
142
|
+
# Access outputs
|
|
143
|
+
loss = outputs["loss"]
|
|
144
|
+
|
|
145
|
+
# Do something
|
|
146
|
+
if loss > 10.0:
|
|
147
|
+
print(f"High loss detected: {loss}")
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
### Training Hooks
|
|
151
|
+
|
|
152
|
+
Available training hooks:
|
|
153
|
+
|
|
154
|
+
```python
|
|
155
|
+
class MyCallback(Callback):
|
|
156
|
+
def on_train_start(self, trainer):
|
|
157
|
+
"""Called when training starts"""
|
|
158
|
+
pass
|
|
159
|
+
|
|
160
|
+
def on_train_epoch_start(self, trainer):
|
|
161
|
+
"""Called at the start of each training epoch"""
|
|
162
|
+
pass
|
|
163
|
+
|
|
164
|
+
def on_train_batch_start(self, trainer, batch, batch_idx):
|
|
165
|
+
"""Called before each training batch"""
|
|
166
|
+
pass
|
|
167
|
+
|
|
168
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
169
|
+
"""Called after each training batch"""
|
|
170
|
+
pass
|
|
171
|
+
|
|
172
|
+
def on_train_epoch_end(self, trainer):
|
|
173
|
+
"""Called at the end of each training epoch"""
|
|
174
|
+
pass
|
|
175
|
+
|
|
176
|
+
def on_train_end(self, trainer):
|
|
177
|
+
"""Called when training ends"""
|
|
178
|
+
pass
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
### Validation Hooks
|
|
182
|
+
|
|
183
|
+
Available validation hooks:
|
|
184
|
+
|
|
185
|
+
```python
|
|
186
|
+
class MyCallback(Callback):
|
|
187
|
+
def on_validation_start(self, trainer):
|
|
188
|
+
"""Called when validation starts"""
|
|
189
|
+
pass
|
|
190
|
+
|
|
191
|
+
def on_validation_epoch_start(self, trainer):
|
|
192
|
+
"""Called at the start of each validation epoch"""
|
|
193
|
+
pass
|
|
194
|
+
|
|
195
|
+
def on_validation_batch_start(self, trainer, batch, batch_idx):
|
|
196
|
+
"""Called before each validation batch"""
|
|
197
|
+
pass
|
|
198
|
+
|
|
199
|
+
def on_validation_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
200
|
+
"""Called after each validation batch"""
|
|
201
|
+
pass
|
|
202
|
+
|
|
203
|
+
def on_validation_epoch_end(self, trainer):
|
|
204
|
+
"""Called at the end of each validation epoch"""
|
|
205
|
+
pass
|
|
206
|
+
|
|
207
|
+
def on_validation_end(self, trainer):
|
|
208
|
+
"""Called when validation ends"""
|
|
209
|
+
pass
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
### State Management
|
|
213
|
+
|
|
214
|
+
Callbacks can maintain their own state:
|
|
215
|
+
|
|
216
|
+
```python
|
|
217
|
+
class StatefulCallback(Callback):
|
|
218
|
+
def __init__(self):
|
|
219
|
+
super().__init__()
|
|
220
|
+
self.best_metric = float('inf')
|
|
221
|
+
self.patience_counter = 0
|
|
222
|
+
|
|
223
|
+
def on_validation_epoch_end(self, trainer):
|
|
224
|
+
# Get current metric
|
|
225
|
+
current_metric = trainer.get_metric("val_loss")
|
|
226
|
+
|
|
227
|
+
# Update state
|
|
228
|
+
if current_metric < self.best_metric:
|
|
229
|
+
self.best_metric = current_metric
|
|
230
|
+
self.patience_counter = 0
|
|
231
|
+
else:
|
|
232
|
+
self.patience_counter += 1
|
|
233
|
+
|
|
234
|
+
# Check patience
|
|
235
|
+
if self.patience_counter >= 5:
|
|
236
|
+
trainer.should_stop = True
|
|
237
|
+
```
|
|
238
|
+
|
|
239
|
+
### Accessing Trainer
|
|
240
|
+
|
|
241
|
+
Callbacks have access to the trainer instance:
|
|
242
|
+
|
|
243
|
+
```python
|
|
244
|
+
class TrainerAwareCallback(Callback):
|
|
245
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
246
|
+
# Access trainer attributes
|
|
247
|
+
model = trainer.model
|
|
248
|
+
optimizer = trainer.optimizer
|
|
249
|
+
current_epoch = trainer.current_epoch
|
|
250
|
+
|
|
251
|
+
# Access trainer methods
|
|
252
|
+
trainer.log("custom_metric", 42)
|
|
253
|
+
trainer.save_checkpoint("path/to/checkpoint.pt")
|
|
254
|
+
```
|
|
255
|
+
|
|
256
|
+
## Callback Collection
|
|
257
|
+
|
|
258
|
+
### Adding Callbacks
|
|
259
|
+
|
|
260
|
+
Add callbacks to a collection:
|
|
261
|
+
|
|
262
|
+
```python
|
|
263
|
+
from dream_trainer.callbacks import CallbackCollection
|
|
264
|
+
|
|
265
|
+
callbacks = CallbackCollection([
|
|
266
|
+
LoggerCallback(),
|
|
267
|
+
ProgressBar(),
|
|
268
|
+
MyCustomCallback()
|
|
269
|
+
])
|
|
270
|
+
```
|
|
271
|
+
|
|
272
|
+
### Removing Callbacks
|
|
273
|
+
|
|
274
|
+
Remove callbacks from a collection:
|
|
275
|
+
|
|
276
|
+
```python
|
|
277
|
+
# Remove by type
|
|
278
|
+
callbacks.remove(LoggerCallback)
|
|
279
|
+
|
|
280
|
+
# Remove by instance
|
|
281
|
+
callbacks.remove(my_callback)
|
|
282
|
+
```
|
|
283
|
+
|
|
284
|
+
### Reordering Callbacks
|
|
285
|
+
|
|
286
|
+
Change callback order:
|
|
287
|
+
|
|
288
|
+
```python
|
|
289
|
+
# Move to front
|
|
290
|
+
callbacks.move_to_front(my_callback)
|
|
291
|
+
|
|
292
|
+
# Move to back
|
|
293
|
+
callbacks.move_to_back(my_callback)
|
|
294
|
+
|
|
295
|
+
# Move to specific position
|
|
296
|
+
callbacks.move_to_position(my_callback, 2)
|
|
297
|
+
```
|
|
298
|
+
|
|
299
|
+
## Best Practices
|
|
300
|
+
|
|
301
|
+
### 1. Keep Callbacks Focused
|
|
302
|
+
|
|
303
|
+
Each callback should do one thing well:
|
|
304
|
+
|
|
305
|
+
```python
|
|
306
|
+
# Good: Single responsibility
|
|
307
|
+
class LossMonitor(Callback):
|
|
308
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
309
|
+
if outputs["loss"] > 10.0:
|
|
310
|
+
print("High loss detected")
|
|
311
|
+
|
|
312
|
+
# Bad: Multiple responsibilities
|
|
313
|
+
class BadCallback(Callback):
|
|
314
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
315
|
+
# Monitoring
|
|
316
|
+
if outputs["loss"] > 10.0:
|
|
317
|
+
print("High loss detected")
|
|
318
|
+
# Logging
|
|
319
|
+
trainer.log("custom_metric", 42)
|
|
320
|
+
# Checkpointing
|
|
321
|
+
trainer.save_checkpoint("checkpoint.pt")
|
|
322
|
+
```
|
|
323
|
+
|
|
324
|
+
### 2. Use Type Hints
|
|
325
|
+
|
|
326
|
+
Add type hints for better IDE support:
|
|
327
|
+
|
|
328
|
+
```python
|
|
329
|
+
from typing import Dict, Any
|
|
330
|
+
import torch
|
|
331
|
+
|
|
332
|
+
class TypedCallback(Callback):
|
|
333
|
+
def on_train_batch_end(
|
|
334
|
+
self,
|
|
335
|
+
trainer: "DreamTrainer",
|
|
336
|
+
outputs: Dict[str, torch.Tensor],
|
|
337
|
+
batch: torch.Tensor,
|
|
338
|
+
batch_idx: int
|
|
339
|
+
) -> None:
|
|
340
|
+
pass
|
|
341
|
+
```
|
|
342
|
+
|
|
343
|
+
### 3. Document Callbacks
|
|
344
|
+
|
|
345
|
+
Add docstrings to explain functionality:
|
|
346
|
+
|
|
347
|
+
```python
|
|
348
|
+
class DocumentedCallback(Callback):
|
|
349
|
+
"""Monitors training metrics and logs warnings.
|
|
350
|
+
|
|
351
|
+
This callback watches for:
|
|
352
|
+
- High loss values
|
|
353
|
+
- NaN gradients
|
|
354
|
+
- Learning rate spikes
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
loss_threshold: Threshold for high loss warning
|
|
358
|
+
lr_threshold: Threshold for learning rate warning
|
|
359
|
+
"""
|
|
360
|
+
|
|
361
|
+
def __init__(self, loss_threshold: float = 10.0, lr_threshold: float = 1e-2):
|
|
362
|
+
super().__init__()
|
|
363
|
+
self.loss_threshold = loss_threshold
|
|
364
|
+
self.lr_threshold = lr_threshold
|
|
365
|
+
```
|
|
366
|
+
|
|
367
|
+
### 4. Handle Errors
|
|
368
|
+
|
|
369
|
+
Add proper error handling:
|
|
370
|
+
|
|
371
|
+
```python
|
|
372
|
+
class ErrorHandlingCallback(Callback):
|
|
373
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
374
|
+
try:
|
|
375
|
+
# Risky operation
|
|
376
|
+
self.process_outputs(outputs)
|
|
377
|
+
except Exception as e:
|
|
378
|
+
# Log error but don't crash
|
|
379
|
+
trainer.log("callback_error", str(e))
|
|
380
|
+
```
|
|
381
|
+
|
|
382
|
+
### 5. Test Callbacks
|
|
383
|
+
|
|
384
|
+
Write unit tests for your callbacks:
|
|
385
|
+
|
|
386
|
+
```python
|
|
387
|
+
def test_my_callback():
|
|
388
|
+
# Create mock trainer
|
|
389
|
+
trainer = MockTrainer()
|
|
390
|
+
|
|
391
|
+
# Create callback
|
|
392
|
+
callback = MyCallback()
|
|
393
|
+
|
|
394
|
+
# Test hook
|
|
395
|
+
callback.on_train_batch_end(
|
|
396
|
+
trainer,
|
|
397
|
+
outputs={"loss": torch.tensor(5.0)},
|
|
398
|
+
batch=torch.randn(32, 10),
|
|
399
|
+
batch_idx=0
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Assert expected behavior
|
|
403
|
+
assert trainer.logged_metrics["custom_metric"] == 42
|
|
404
|
+
```
|
|
405
|
+
|
|
406
|
+
### 6. Use Callback Priority
|
|
407
|
+
|
|
408
|
+
Set a callback's priority to control its execution order:
|
|
409
|
+
|
|
410
|
+
```python
|
|
411
|
+
class HighPriorityCallback(Callback):
|
|
412
|
+
priority = 100 # Higher number = earlier execution
|
|
413
|
+
|
|
414
|
+
class LowPriorityCallback(Callback):
|
|
415
|
+
priority = 0 # Lower number = later execution
|
|
416
|
+
```
|
|
417
|
+
|
|
418
|
+
### 7. Avoid Side Effects
|
|
419
|
+
|
|
420
|
+
Minimize side effects in callbacks:
|
|
421
|
+
|
|
422
|
+
```python
|
|
423
|
+
class CleanCallback(Callback):
|
|
424
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
425
|
+
# Good: Only logging
|
|
426
|
+
trainer.log("metric", outputs["loss"])
|
|
427
|
+
|
|
428
|
+
# Bad: Modifying trainer state
|
|
429
|
+
trainer.model.requires_grad_(False) # Don't do this
|
|
430
|
+
```
|
|
431
|
+
|
|
432
|
+
### 8. Use Callback Groups
|
|
433
|
+
|
|
434
|
+
Group related callbacks:
|
|
435
|
+
|
|
436
|
+
```python
|
|
437
|
+
class MonitoringGroup(Callback):
|
|
438
|
+
"""Group of monitoring callbacks"""
|
|
439
|
+
|
|
440
|
+
def __init__(self):
|
|
441
|
+
super().__init__()
|
|
442
|
+
self.callbacks = [
|
|
443
|
+
LossMonitor(),
|
|
444
|
+
GradientMonitor(),
|
|
445
|
+
LearningRateMonitor()
|
|
446
|
+
]
|
|
447
|
+
|
|
448
|
+
def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
|
|
449
|
+
for callback in self.callbacks:
|
|
450
|
+
callback.on_train_batch_end(trainer, outputs, batch, batch_idx)
|
|
451
|
+
```
|
|
452
|
+
|
|
453
|
+
## Next Steps
|
|
454
|
+
|
|
455
|
+
- Explore [Examples](examples.md) to see callbacks in action
|
|
456
|
+
- Read about [Distributed Training](parallelism.md) for multi-GPU callback considerations
|
|
457
|
+
- Check the [API Reference](api-reference.md) for detailed callback documentation
|