sarasa 0.0.2__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sarasa-0.0.2 → sarasa-0.0.4}/PKG-INFO +10 -4
- {sarasa-0.0.2 → sarasa-0.0.4}/README.md +9 -3
- {sarasa-0.0.2 → sarasa-0.0.4}/configs/example.py +5 -2
- {sarasa-0.0.2 → sarasa-0.0.4}/configs/llama3-1b.py +4 -3
- sarasa-0.0.4/sarasa/__init__.py +11 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/checkpoint.py +10 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/config.py +21 -4
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/__init__.py +2 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/attention.py +6 -2
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/llama3.py +18 -7
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/nanochat_gpt.py +10 -1
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/utils.py +0 -9
- sarasa-0.0.2/sarasa/trainer.py → sarasa-0.0.4/sarasa/train.py +65 -39
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/utils.py +8 -4
- {sarasa-0.0.2 → sarasa-0.0.4}/tests/test_config.py +7 -0
- sarasa-0.0.2/sarasa/__init__.py +0 -2
- {sarasa-0.0.2 → sarasa-0.0.4}/.github/workflows/pypi.yaml +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/.github/workflows/tests_and_lint.yaml +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/.gitignore +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/LICENSE +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/main.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/pyproject.toml +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/activation_checkpoint.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/data/__init__.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/data/hf_datasets.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/data/tokenizer.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/metrics.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/optimizers/__init__.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/optimizers/utils.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/tests/test_model.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.4}/tests/test_utils.py +0 -0
{sarasa-0.0.2 → sarasa-0.0.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sarasa
-Version: 0.0.2
+Version: 0.0.4
 Summary: Add your description here
 License-File: LICENSE
 Requires-Python: >=3.13

@@ -46,6 +46,8 @@ uv add sarasa[cpu|cu128|cu130]
 - Async distributed checkpoint saving
 
 - [ ] Checkpoint loading
+- [ ] FP8 training
+- [ ] Profiling
 
 ## Usage
 

@@ -100,18 +102,22 @@ if __name__ == "__main__":
     trainer.train()
 ```
 
+Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
 From the command line, you can specify which custom optimizer to use:
 
 ```bash
 python script.py optim:custom_optim --optim.lr 0.001 ...
 ```
 
+(As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
 ### Config File Example
 
-It's very simple.
+It's very simple.
+IDE autocompletion will help you.
 
 ```python
-from sarasa
+from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
 from custom_optim import CustomOptim
 
 # only one Config instance should be defined in each config file

@@ -135,4 +141,4 @@ config = Config.create(
 
 ## Acknowledgements
 
-This project is heavily inspired by and borrows code from
+This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
{sarasa-0.0.2 → sarasa-0.0.4}/README.md

@@ -23,6 +23,8 @@ uv add sarasa[cpu|cu128|cu130]
 - Async distributed checkpoint saving
 
 - [ ] Checkpoint loading
+- [ ] FP8 training
+- [ ] Profiling
 
 ## Usage
 

@@ -77,18 +79,22 @@ if __name__ == "__main__":
     trainer.train()
 ```
 
+Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
 From the command line, you can specify which custom optimizer to use:
 
 ```bash
 python script.py optim:custom_optim --optim.lr 0.001 ...
 ```
 
+(As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
 ### Config File Example
 
-It's very simple.
+It's very simple.
+IDE autocompletion will help you.
 
 ```python
-from sarasa
+from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
 from custom_optim import CustomOptim
 
 # only one Config instance should be defined in each config file

@@ -112,4 +118,4 @@ config = Config.create(
 
 ## Acknowledgements
 
-This project is heavily inspired by and borrows code from
+This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
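For illustration, a minimal sketch of what a `custom_optim` module could look like so that tyro exposes it as the `optim:custom_optim` subcommand used above. Only the CLI pattern and the `custom_optim` / `CustomOptim` names come from the README; the fields below are hypothetical.

```python
# custom_optim.py -- hypothetical example module, not part of sarasa
import dataclasses


@dataclasses.dataclass
class CustomOptim:
    """Selectable on the CLI as `optim:custom_optim` (tyro converts CamelCase to snake_case)."""

    lr: float = 0.001           # overridable via --optim.lr
    weight_decay: float = 0.0
    betas: tuple[float, float] = (0.9, 0.95)
```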
{sarasa-0.0.2 → sarasa-0.0.4}/configs/example.py

@@ -1,11 +1,14 @@
 from sarasa.config import AdamW, Config, Data, LRScheduler, Model, Train
 
 config = Config.create(
-    model=Model(
+    model=Model(
+        name="nanochat_gpt",
+        num_layers=12,
+        qk_norm=True,
+    ),
     train=Train(
         local_batch_size=16,
         global_batch_size=256,
-        dtype="bfloat16",
     ),
     data=Data(tokenizer_path="./tokenizer"),
     lr_scheduler=LRScheduler(
{sarasa-0.0.2 → sarasa-0.0.4}/configs/llama3-1b.py

@@ -2,17 +2,18 @@ from sarasa.config import FSDP, AdamW, Config, Data, LRScheduler, Model, Train
 
 config = Config.create(
     model=Model(
+        name="llama3",
         hidden_dim=2048,
         num_layers=16,
         num_heads=32,
         num_kv_heads=8,
         head_dim=64,
-
+        rms_eps=1e-5,
+        rms_learnable=True,
     ),
     train=Train(
         local_batch_size=32,
-        global_batch_size=
-        dtype="bfloat16",
+        global_batch_size=1024,
         use_sac=True,
     ),
     data=Data(tokenizer_path="./tokenizer"),
sarasa-0.0.4/sarasa/__init__.py (new file)

@@ -0,0 +1,11 @@
+from .config import DDP as DDP
+from .config import FSDP as FSDP
+from .config import AdamW as AdamW
+from .config import Checkpoint as Checkpoint
+from .config import Config as Config
+from .config import Data as Data
+from .config import LRScheduler as LRScheduler
+from .config import Metrics as Metrics
+from .config import Model as Model
+from .config import Train as Train
+from .train import Trainer as Trainer
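With this new top-level `__init__.py`, config files can import directly from the package root instead of `sarasa.config`. A minimal sketch assembled from the re-exports above and the updated `configs/example.py`; whether every section shown in `configs/example.py` is strictly required is not confirmed by this diff.

```python
from sarasa import Config, Data, Model, Train

config = Config.create(
    model=Model(name="nanochat_gpt", num_layers=12, qk_norm=True),
    train=Train(local_batch_size=16, global_batch_size=256),
    data=Data(tokenizer_path="./tokenizer"),
)
```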
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/checkpoint.py

@@ -110,3 +110,13 @@ class Checkpointer:
     def close(self) -> None:
         if self.stager is not None:
             self.stager.close()
+
+        if self.save_future is not None:
+            self.save_future.result()
+
+        if self.pg is not None:
+            dist.destroy_process_group(self.pg)
+            self.pg = None
+
+    def __del__(self):
+        self.close()
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/config.py

@@ -91,6 +91,10 @@ class Train:
     grad_clip: float | None = None
 
     dtype: Literal["bfloat16", "float32"] = "float32"
+    """Dtype used for model initialization"""
+
+    amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+    """Dtype used for automatic mixed precision training"""
 
     compile: bool = False
 

@@ -154,6 +158,12 @@ class FSDP(Distributed):
     reshard_after_forward: bool = False
     """Whether to reshard model parameters after each forward pass (FSDP only)."""
 
+    dtype: str | None = None
+    """Dtype for FSDP reduce operations. If None, uses train.dtype."""
+
+    amp_dtype: str | None = None
+    """Dtype for FSDP parameter storage. If None, uses train.amp_dtype."""
+
 
 @dataclasses.dataclass
 class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:

@@ -183,11 +193,15 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
         if self.output_dir is not None:
             self.output_dir.mkdir(parents=True, exist_ok=True)
 
-        if hasattr(self.model, "seq_len")
-            if self.data.seq_len is not None:
+        if hasattr(self.model, "seq_len"):
+            if self.model.seq_len is None and self.data.seq_len is not None:
                 self.model.seq_len = self.data.seq_len
-
-            raise ValueError("
+            if self.model.seq_len is None:
+                raise ValueError("seq_len must be specified in either model or data configuration.")
+
+        if isinstance(self.distributed, FSDP):
+            self.distributed.dtype = self.distributed.dtype or self.train.dtype
+            self.distributed.amp_dtype = self.distributed.amp_dtype or self.train.amp_dtype
 
     @classmethod
     def create(

@@ -227,6 +241,8 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
 
         import tyro
 
+        from sarasa.utils import rank
+
         loaded_config = None
 
         if (under := ("--config_file" in sys.argv)) or ("--config-file" in sys.argv):

@@ -262,6 +278,7 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
                 data_type,
             ],
            default=loaded_config,
+            console_outputs=(rank() == 0),
         )
 
 
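The dtype fallback added in `__post_init__` can be exercised exactly as the new test at the end of this diff does. A short sketch; the default values in the comments come from the `Train` fields above.

```python
import sys

from sarasa import Config

# Selecting the FSDP subcommand without setting its dtypes: both fields
# inherit from the Train section in Config.__post_init__.
sys.argv = ["program", "distributed:fsdp"]
config = Config.from_cli()
assert config.distributed.dtype == config.train.dtype          # "float32" by default
assert config.distributed.amp_dtype == config.train.amp_dtype  # "bfloat16" by default
```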
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/__init__.py

@@ -21,6 +21,8 @@ class ModelConfig:
     vocab_size: int | None = None  # set later based on tokenizer
     seq_len: int | None = None  # set later based on data config
     qk_norm: bool = False  # whether to use RMSNorm on q/k
+    rms_eps: float | None = None  # epsilon for RMSNorm, default to library default if None
+    rms_learnable: bool = False  # whether RMSNorm has learnable scale parameter
 
     def __post_init__(self):
         # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/attention.py

@@ -3,7 +3,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from sarasa.models import ModelConfig
-from sarasa.models.utils import
+from sarasa.models.utils import RoPE
 
 
 class SDPAttention(nn.Module):

@@ -57,7 +57,11 @@ class CausalSelfAttention(nn.Module):
         self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
         self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
-        self.qk_norm =
+        self.qk_norm = (
+            nn.RMSNorm(self.head_dim, eps=config.rms_eps, elementwise_affine=config.rms_learnable)
+            if config.qk_norm
+            else nn.Identity()
+        )
 
         # todo: support varlen etc and kv caching
         self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)
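For context, a per-head RMSNorm like the one constructed above normalizes q and k over `head_dim` before attention. A generic standalone sketch; the actual `CausalSelfAttention.forward` is not shown in this diff.

```python
import torch
from torch import nn

B, T, H, Dh = 2, 8, 4, 64                                    # batch, seq, heads, head_dim
q = torch.randn(B, T, H, Dh)
qk_norm = nn.RMSNorm(Dh, eps=1e-5, elementwise_affine=True)  # mirrors the module built above
q_normed = qk_norm(q)                                        # RMS-normalizes the last (head_dim) axis per head
```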
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/llama3.py

@@ -4,7 +4,7 @@ from torch.nn import functional as F
 
 from sarasa.models import BaseModel, ModelConfig
 from sarasa.models.attention import CausalSelfAttention
-from sarasa.models.utils import
+from sarasa.models.utils import RoPE
 
 
 class MLP(nn.Module):

@@ -41,15 +41,16 @@ class Block(nn.Module):
         self.layer_idx = layer_idx
         self.attention = CausalSelfAttention(config)
         self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
-        self.
+        self.attn_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
+        self.mlp_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
 
     def forward(
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
     ) -> torch.Tensor:
-        x = x + self.attention(self.
-        x = x + self.mlp(self.
+        x = x + self.attention(self.attn_norm(x), cos_sin)
+        x = x + self.mlp(self.mlp_norm(x))
         return x
 

@@ -71,7 +72,7 @@ class Llama3(BaseModel):
         self.blocks = nn.ModuleList([
             Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
         ])
-        self.norm = RMSNorm(config.hidden_dim)
+        self.norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
         self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
 
     @torch.no_grad()

@@ -101,16 +102,26 @@ class Llama3(BaseModel):
                 b=cutoff_factor * final_out_std,
             )
 
+        for mod in self.modules():
+            if isinstance(mod, nn.RMSNorm):
+                mod.reset_parameters()
+
     def param_groups(self) -> dict[str, list[nn.Parameter]]:
-        matrix_params =
+        matrix_params = [param for param in self.blocks.parameters() if param.ndim == 2]
         embedding_params = list(self.token_emb.parameters())
         lm_head_params = list(self.output.parameters())
-
+        rms_norm_params = [
+            mod.weight for mod in self.modules() if isinstance(mod, nn.RMSNorm) and mod.elementwise_affine
+        ]
+        assert len(list(self.parameters())) == (
+            len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(rms_norm_params)
+        )
 
         return {
             "matrix": matrix_params,
             "embedding": embedding_params,
             "lm_head": lm_head_params,
+            "rms_norm": rms_norm_params,
         }
 
     def forward(
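One hypothetical way an optimizer setup could consume the new `"rms_norm"` group returned by `param_groups()` above. The group names come from the diff; the learning rates and the choice to disable weight decay on norm scales are illustrative only.

```python
import torch


def build_optimizer(model, lr=3e-4, norm_lr=1e-3):
    # group names match Llama3.param_groups() in the diff above
    groups = model.param_groups()
    return torch.optim.AdamW([
        {"params": groups["matrix"], "lr": lr},
        {"params": groups["embedding"], "lr": lr},
        {"params": groups["lm_head"], "lr": lr},
        {"params": groups["rms_norm"], "lr": norm_lr, "weight_decay": 0.0},
    ])
```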
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/nanochat_gpt.py

@@ -8,7 +8,16 @@ from torch.nn import functional as F
 
 from sarasa.models import BaseModel, ModelConfig
 from sarasa.models.attention import CausalSelfAttention
-from sarasa.models.utils import
+from sarasa.models.utils import RoPE
+
+
+class RMSNorm(torch.nn.RMSNorm):
+    # RMSNorm without affine parameters
+    def __init__(
+        self,
+        normalized_shape: int,
+    ):
+        super().__init__(normalized_shape, eps=None, elementwise_affine=False)
 
 
 class MLP(nn.Module):
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/utils.py

@@ -1,15 +1,6 @@
 import torch
 
 
-class RMSNorm(torch.nn.RMSNorm):
-    # RMSNorm without affine parameters
-    def __init__(
-        self,
-        normalized_shape: int,
-    ):
-        super().__init__(normalized_shape, eps=None, elementwise_affine=False)
-
-
 class RoPE:
     @staticmethod
     def precompute(
sarasa-0.0.2/sarasa/trainer.py → sarasa-0.0.4/sarasa/train.py

@@ -7,6 +7,7 @@ import torch
 import torch.distributed as dist
 from loguru import logger
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.nn import functional as F
 
 from sarasa.activation_checkpoint import apply_op_sac
 from sarasa.checkpoint import Checkpointer

@@ -48,9 +49,6 @@ class Trainer:
         vocab_size = len(self.tokenizer)
         self.config.model.vocab_size = vocab_size
 
-        # todo: support other loss functions
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")
-
         # setup model, optimizer, lr scheduler
         with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
             self.model = self.config.model.create()

@@ -68,9 +66,9 @@ class Trainer:
         if config.train.compile:
             logger.info("Compiling the model")
             for block in self.model.blocks:
-                block.compile(fullgraph=True)
+                block.compile(fullgraph=True, dynamic=False)
             self.model.compile(dynamic=False)
-            self.loss_fn.compile()
+            self.loss_fn = torch.compile(self.loss_fn, fullgraph=True, dynamic=False)
 
         if world_size() > 1:
             apply_distributed(

@@ -102,7 +100,10 @@ class Trainer:
 
         self.amp_context = contextlib.nullcontext()
         if config.distributed.name != "fsdp":
-            self.amp_context = torch.autocast(
+            self.amp_context = torch.autocast(
+                device_type=self.device.type,
+                dtype=getattr(torch, config.train.amp_dtype),
+            )
 
         # todo: setup profiler context
         self.profile_context = contextlib.nullcontext()

@@ -116,40 +117,36 @@ class Trainer:
                 f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
             )
 
-    def __del__(self) -> None:
-        # cleanup distributed
-        if world_size() > 1:
-            try:
-                dist.destroy_process_group()
-            except Exception as e:
-                logger.warning(f"Failed to destroy process group: {e}")
-
     @record
     def train(self):
-        self.
-        self.checkpointer
-        self.
+        try:
+            logger.info("Starting training...")
+
+            self.model.train()
+            with self.profile_context:
+                data_iter = self.batch_generator(self.data_loader)
+                for _ in range(self.config.train.steps):
+                    self.step += 1
+                    self.gc.collect(self.step)
+                    try:
+                        self.train_step(data_iter)
+                    except StopIteration:
+                        logger.warning("Data loader exhausted during training.")
+                        break
+
+                    if self.checkpointer is not None:
+                        self.checkpointer.save(self.step)
+
+                    if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
+                        self.evaluate()
+
+                    if world_size() > 1 and self.step == 1:
+                        update_timeout(self.config.distributed.train_timeout_seconds, self.device)
+
+            logger.info("Training completed.")
+        finally:
+            logger.info("Cleaning up trainer...")
+            self.close()
 
     def batch_generator(
         self,

@@ -190,7 +187,7 @@ class Trainer:
 
         with self.amp_context:
             pred = self.model(**input_dict)
-            loss = self.loss_fn(pred
+            loss = self.loss_fn(pred, target) / valid_tokens
 
         del pred
         loss.backward()

@@ -234,6 +231,18 @@ class Trainer:
             },
         )
 
+    def loss_fn(
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        return F.cross_entropy(
+            pred.flatten(0, 1).float(),
+            target.flatten(0, 1),
+            ignore_index=IGNORE_INDEX,
+            reduction="sum",
+        )
+
     def evaluate(self):
         raise NotImplementedError
 

@@ -242,3 +251,20 @@ class Trainer:
         batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
     ) -> None:
         raise NotImplementedError
+
+    def close(self) -> None:
+        if self.checkpointer is not None:
+            self.checkpointer.close()
+
+        if self.metrics_processor is not None:
+            self.metrics_processor.close()
+
+        # cleanup distributed
+        if world_size() > 1:
+            try:
+                dist.destroy_process_group()
+            except Exception as e:
+                logger.warning(f"Failed to destroy process group: {e}")
+
+    def __del__(self):
+        self.close()
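A small standalone sketch of the loss convention the new `loss_fn` uses: cross-entropy is summed over non-ignored tokens and then divided by `valid_tokens`, so the result is a mean over valid tokens that is unaffected by padding or by how a global batch is split into micro-batches. IGNORE_INDEX = -100 is assumed here, matching PyTorch's default ignore value; the diff only references the name.

```python
import torch
from torch.nn import functional as F

IGNORE_INDEX = -100  # assumed value; sarasa only exposes the name in this diff

pred = torch.randn(2, 5, 11)            # (batch, seq, vocab)
target = torch.randint(0, 11, (2, 5))
target[:, -1] = IGNORE_INDEX            # e.g. padded / shifted-out positions

valid_tokens = (target != IGNORE_INDEX).sum()
loss_sum = F.cross_entropy(
    pred.flatten(0, 1).float(),
    target.flatten(0, 1),
    ignore_index=IGNORE_INDEX,
    reduction="sum",
)
loss = loss_sum / valid_tokens          # mean over valid tokens only
```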
{sarasa-0.0.2 → sarasa-0.0.4}/sarasa/utils.py

@@ -3,6 +3,7 @@ import gc
 import os
 import sys
 import time
+import typing
 from datetime import timedelta
 from functools import cache
 

@@ -11,8 +12,11 @@ from loguru import logger
 from torch import distributed as dist
 from torch import nn
 
+if typing.TYPE_CHECKING:
+    from sarasa.config import Config, Distributed
 
-
+
+def setup_logger(config: Config) -> None:
     logger.remove()
     if config.debug:
         logger_format = f"<blue>RANK={rank()}</blue> | " + (

@@ -128,7 +132,7 @@ def update_timeout(
 
 
 def apply_distributed(
-    config,
+    config: Distributed,
     model: nn.Module,
     device: torch.device,
     compile: bool,

@@ -149,8 +153,8 @@ def apply_distributed(
 
     # todo: make dtypes configurable
     mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.
-        reduce_dtype=torch.
+        param_dtype=getattr(torch, config.amp_dtype),
+        reduce_dtype=getattr(torch, config.dtype),
    )
 
     for block in model.blocks:
{sarasa-0.0.2 → sarasa-0.0.4}/tests/test_config.py

@@ -73,3 +73,10 @@ def test_config_loading_filetype_error(monkeypatch, tmp_path):
     monkeypatch.setattr(sys, "argv", ["program", "--config_file", str(config_file)])
     with pytest.raises(ValueError):
         Config.from_cli()
+
+
+def test_config_post_init(monkeypatch):
+    monkeypatch.setattr(sys, "argv", ["program", "distributed:fsdp"])
+    config = Config.from_cli()  # just check no error is raised
+    assert config.distributed.dtype == config.train.dtype
+    assert config.distributed.amp_dtype == config.train.amp_dtype
sarasa-0.0.2/sarasa/__init__.py: DELETED

The remaining files listed above are unchanged between 0.0.2 and 0.0.4.