sarasa 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sarasa/__init__.py +2 -0
- sarasa/activation_checkpoint.py +81 -0
- sarasa/checkpoint.py +112 -0
- sarasa/config.py +279 -0
- sarasa/data/__init__.py +36 -0
- sarasa/data/hf_datasets.py +115 -0
- sarasa/data/tokenizer.py +63 -0
- sarasa/metrics.py +294 -0
- sarasa/models/__init__.py +95 -0
- sarasa/models/attention.py +84 -0
- sarasa/models/llama3.py +129 -0
- sarasa/models/nanochat_gpt.py +192 -0
- sarasa/models/utils.py +39 -0
- sarasa/optimizers/__init__.py +77 -0
- sarasa/optimizers/utils.py +27 -0
- sarasa/trainer.py +244 -0
- sarasa/utils.py +163 -0
- sarasa-0.0.2.dist-info/METADATA +138 -0
- sarasa-0.0.2.dist-info/RECORD +21 -0
- sarasa-0.0.2.dist-info/WHEEL +4 -0
- sarasa-0.0.2.dist-info/licenses/LICENSE +201 -0
sarasa/activation_checkpoint.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
|
|
5
|
+
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
|
|
6
|
+
|
|
7
|
+
# for selective op activation checkpointing
|
|
8
|
+
_ops_sac_save = {
|
|
9
|
+
torch.ops.aten.mm.default,
|
|
10
|
+
torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
|
11
|
+
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
|
12
|
+
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
|
|
13
|
+
torch.ops.aten._scaled_dot_product_attention_math.default,
|
|
14
|
+
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
|
|
15
|
+
torch.ops._c10d_functional.reduce_scatter_tensor.default,
|
|
16
|
+
# for low precision training, it's useful to always save
|
|
17
|
+
# the result of max, since the absolute maximum is
|
|
18
|
+
# used to compute the scaling factor for quantization.
|
|
19
|
+
torch.ops.aten.max.default,
|
|
20
|
+
torch._higher_order_ops.inductor_compiled_code,
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _op_sac_policy(
|
|
25
|
+
ops_to_save: set,
|
|
26
|
+
mm_recompute_shapes: set | None,
|
|
27
|
+
every_nth_mm: int,
|
|
28
|
+
):
|
|
29
|
+
mm_recompute_shapes = mm_recompute_shapes or set()
|
|
30
|
+
|
|
31
|
+
def _get_custom_policy(meta: dict):
|
|
32
|
+
def _custom_policy(ctx, func, *args, **kwargs):
|
|
33
|
+
# special case, offload to CPU
|
|
34
|
+
if (
|
|
35
|
+
func == torch.ops.aten._to_copy.default
|
|
36
|
+
and "cuda" in str(args[0].device)
|
|
37
|
+
and str(kwargs.get("device", "")) == "cpu"
|
|
38
|
+
):
|
|
39
|
+
return CheckpointPolicy.MUST_SAVE
|
|
40
|
+
|
|
41
|
+
# track mm ops
|
|
42
|
+
mode = "recompute" if ctx.is_recompute else "forward"
|
|
43
|
+
key = f"{mode}_mm_count"
|
|
44
|
+
|
|
45
|
+
if func == torch.ops.aten.mm.default:
|
|
46
|
+
if len(args) > 1 and args[1].shape in mm_recompute_shapes:
|
|
47
|
+
# moe's router
|
|
48
|
+
return CheckpointPolicy.PREFER_RECOMPUTE
|
|
49
|
+
meta[key] += 1
|
|
50
|
+
|
|
51
|
+
# save ops in save list, except every nth mm op
|
|
52
|
+
must_save = (func in ops_to_save) and not (
|
|
53
|
+
func == torch.ops.aten.mm.default and (meta[key] % every_nth_mm == 0)
|
|
54
|
+
)
|
|
55
|
+
return CheckpointPolicy.MUST_SAVE if must_save else CheckpointPolicy.PREFER_RECOMPUTE
|
|
56
|
+
|
|
57
|
+
return _custom_policy
|
|
58
|
+
|
|
59
|
+
def selective_checkpointing_context_fn():
|
|
60
|
+
return create_selective_checkpoint_contexts(_get_custom_policy(defaultdict(int)))
|
|
61
|
+
|
|
62
|
+
return selective_checkpointing_context_fn
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def apply_op_sac(
    model: torch.nn.Module,
    ops_to_save: set | None = None,
    mm_recompute_shapes: set | None = None,
    every_nth_mm: int = 2,
) -> torch.nn.Module:
    """Applies selective op activation checkpointing to the given model.

    Ops like mm are expensive, so their outputs are stored for the backward pass.
    Cheap ops, such as activation functions, are recomputed instead.

    Args:
        model: module to wrap.
        ops_to_save: ops whose outputs are saved. ``None`` selects the module default
            ``_ops_sac_save``; an explicit empty set saves nothing.
        mm_recompute_shapes: shapes of the second ``mm`` operand whose matmuls are
            always recomputed (e.g. a MoE router).
        every_nth_mm: every n-th saved ``mm`` is recomputed instead.

    Returns:
        ``model`` wrapped by ``checkpoint_wrapper``.
    """
    if ops_to_save is None:
        ops_to_save = _ops_sac_save
    # The policy factory must be passed as ``context_fn``, which is forwarded to
    # torch.utils.checkpoint.checkpoint. Passed positionally, it would land in
    # ``checkpoint_impl``, and the wrapper would silently do full recomputation.
    return checkpoint_wrapper(
        model,
        context_fn=_op_sac_policy(ops_to_save, mm_recompute_shapes, every_nth_mm),
    )
|
sarasa/checkpoint.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import gc
|
|
3
|
+
import time
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
import torch.distributed.checkpoint as dcp
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
|
|
11
|
+
from torch.distributed.checkpoint.state_dict import get_model_state_dict
|
|
12
|
+
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
|
|
13
|
+
from torch.distributed.checkpoint.stateful import Stateful
|
|
14
|
+
|
|
15
|
+
from sarasa.config import Config
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AsyncMode(enum.StrEnum):
|
|
19
|
+
none = enum.auto()
|
|
20
|
+
default = enum.auto()
|
|
21
|
+
mem_pinned = enum.auto()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ModelWrapper(Stateful):
|
|
25
|
+
def __init__(self, model: torch.nn.Module):
|
|
26
|
+
self.model = model
|
|
27
|
+
|
|
28
|
+
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
29
|
+
return {"model": get_model_state_dict(self.model)}
|
|
30
|
+
|
|
31
|
+
def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
|
|
32
|
+
raise NotImplementedError("...")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Checkpointer:
    """Periodically write the model state with torch.distributed.checkpoint (DCP).

    The write strategy comes from ``config.checkpoint.async_mode``:

    - ``none``: blocking ``dcp.save``.
    - ``default``: ``dcp.async_save``; the previous save is awaited before a new one.
    - ``mem_pinned``: ``dcp.async_save`` with a ``DefaultStager`` and a process-based
      async checkpointer. ``wait_for_staging`` blocks until the staging future of the
      latest save resolves.

    Call ``close`` at the end of training so that the last asynchronous save finishes.
    """

    def __init__(
        self,
        config: Config,
        model: torch.nn.Module,
    ):
        self.config = config
        # Save when step % checkpoint_freq == 0; a non-positive value disables saving.
        self.checkpoint_freq = config.checkpoint.save_freq
        self.checkpoint_dir = Path(config.output_dir) / "checkpoints"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.async_mode = AsyncMode(config.checkpoint.async_mode)

        # Async saves use a dedicated gloo process group. It stays None when
        # torch.distributed is not initialized (single process) or in sync mode.
        self.pg = None
        if self.async_mode != AsyncMode.none and dist.is_initialized():
            self.pg = dist.new_group(backend="gloo")

        self.stager = None
        self.save_future = None  # completion future of the in-flight async save
        self.stage_future = None  # staging future (mem_pinned mode only)

        self.state = ModelWrapper(model)

    def _wait_for_pending_save(self) -> None:
        """Block until the previous asynchronous save (if any) has completed."""
        if self.save_future is not None:
            self.save_future.result()
            self.save_future = None

    @torch.no_grad()
    def save(
        self,
        step: int,
    ) -> None:
        """Write a checkpoint for ``step`` if it falls on the save frequency.

        The checkpoint id is ``<output_dir>/checkpoints/checkpoint_<step:09d>``. In the
        async modes this only schedules the write. The logged duration is the time
        the caller was blocked.
        """
        if self.checkpoint_freq <= 0 or step % self.checkpoint_freq != 0:
            return

        begin = time.perf_counter()
        checkpoint_id = str(self.checkpoint_dir / f"checkpoint_{step:09d}")

        # todo: save other states
        state_dict = self.state.state_dict()

        if self.async_mode == AsyncMode.default:
            gc.collect(1)
            self._wait_for_pending_save()
            self.save_future = dcp.async_save(
                state_dict,
                storage_writer=None,
                checkpoint_id=checkpoint_id,
                process_group=self.pg,
            )
            gc.collect(1)
        elif self.async_mode == AsyncMode.mem_pinned:
            gc.collect(1)
            self._wait_for_pending_save()
            if self.stager is None:
                # All StagingOptions flags enabled (pinned/shared memory, async
                # staging, non-blocking copy; positional order per torch docs).
                self.stager = DefaultStager(StagingOptions(True, True, True, True))
            response = dcp.async_save(
                state_dict,
                storage_writer=None,
                checkpoint_id=checkpoint_id,
                process_group=self.pg,
                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
                async_stager=self.stager,
            )
            self.save_future = response.upload_completion
            self.stage_future = response.staging_completion
        else:
            dcp.save(
                state_dict,
                storage_writer=None,
                checkpoint_id=checkpoint_id,
            )

        logger.info(f"Finished saving checkpoint at step {step} in {time.perf_counter() - begin:.2f} seconds")

    def wait_for_staging(self) -> None:
        """Block until the latest save's state has been staged (no-op outside mem_pinned)."""
        if self.stage_future is not None:
            self.stage_future.result()

    def close(self) -> None:
        """Finish any in-flight save, then release the stager.

        Without the wait, the final asynchronous checkpoint could be lost or left
        incomplete when the process exits.
        """
        self._wait_for_pending_save()
        if self.stager is not None:
            self.stager.close()
            self.stager = None
|
sarasa/config.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
Variable configuration dataclasses for model, optimizer, lr scheduler, and data
|
|
12
|
+
These classes have `create` methods to instantiate the actual objects
|
|
13
|
+
|
|
14
|
+
Users can define their own configuration dataclasses and pass them to Config.from_cli to use custom components
|
|
15
|
+
"""
|
|
16
|
+
from sarasa.data import DataConfig as Data # noqa
|
|
17
|
+
from sarasa.models import ModelConfig as Model # noqa
|
|
18
|
+
from sarasa.optimizers import AdamW # noqa
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclasses.dataclass
|
|
22
|
+
class LRScheduler:
|
|
23
|
+
warmup_steps: int = 200
|
|
24
|
+
decay_ratio: float | None = None
|
|
25
|
+
"""If set, the ratio of total steps to apply decay after warmup. If None, decay starts immediately after warmup."""
|
|
26
|
+
|
|
27
|
+
decay_type: Literal["linear", "cosine", "sqrt"] = "linear"
|
|
28
|
+
min_lr_factor: float = 0.0
|
|
29
|
+
|
|
30
|
+
def create(
|
|
31
|
+
self,
|
|
32
|
+
optimizer: torch.optim.Optimizer,
|
|
33
|
+
total_iters: int,
|
|
34
|
+
) -> torch.optim.lr_scheduler._LRScheduler:
|
|
35
|
+
assert self.decay_ratio is None or (0 <= self.decay_ratio <= 1), "decay_ratio must be between 0 and 1"
|
|
36
|
+
warmup_steps = self.warmup_steps
|
|
37
|
+
stay_steps = 0 if self.decay_ratio is None else int(total_iters * (1 - self.decay_ratio)) - warmup_steps
|
|
38
|
+
decay_steps = total_iters - warmup_steps - stay_steps
|
|
39
|
+
assert warmup_steps >= 0 and decay_steps >= 0 and stay_steps >= 0, (
|
|
40
|
+
f"Invalid lr scheduler steps configuration: {warmup_steps=}, {decay_steps=}, {stay_steps=}"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# 1 / max(1, warmup_steps) to avoid division by zero
|
|
44
|
+
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, 1 / max(1, warmup_steps), total_iters=warmup_steps)
|
|
45
|
+
|
|
46
|
+
stay = torch.optim.lr_scheduler.ConstantLR(optimizer=optimizer, factor=1.0, total_iters=stay_steps)
|
|
47
|
+
|
|
48
|
+
match self.decay_type:
|
|
49
|
+
case "linear":
|
|
50
|
+
decay = torch.optim.lr_scheduler.LinearLR(
|
|
51
|
+
optimizer,
|
|
52
|
+
start_factor=1.0,
|
|
53
|
+
end_factor=self.min_lr_factor,
|
|
54
|
+
total_iters=decay_steps,
|
|
55
|
+
)
|
|
56
|
+
case "sqrt":
|
|
57
|
+
decay = torch.optim.lr_scheduler.LambdaLR(
|
|
58
|
+
optimizer,
|
|
59
|
+
lr_lambda=lambda step: max(
|
|
60
|
+
self.min_lr_factor,
|
|
61
|
+
(decay_steps - step) / decay_steps,
|
|
62
|
+
)
|
|
63
|
+
** 0.5,
|
|
64
|
+
)
|
|
65
|
+
case "cosine":
|
|
66
|
+
decay = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
67
|
+
optimizer,
|
|
68
|
+
T_max=decay_steps,
|
|
69
|
+
eta_min=optimizer.param_groups[0]["lr"] * self.min_lr_factor,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
scheduler = torch.optim.lr_scheduler.SequentialLR(
|
|
73
|
+
optimizer,
|
|
74
|
+
[warmup, stay, decay],
|
|
75
|
+
[self.warmup_steps, self.warmup_steps + stay_steps],
|
|
76
|
+
)
|
|
77
|
+
return scheduler
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
"""
|
|
81
|
+
Static configuration dataclasses
|
|
82
|
+
|
|
83
|
+
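Unlike the variable components above, these classes are not meant to be replaced with user-defined types; their field values can still be overridden (e.g. from the CLI).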
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclasses.dataclass
class Train:
    # Training-loop settings. Field order defines the generated __init__ (and CLI
    # order), so fields are documented in place rather than reordered.

    # Number of training steps (NOTE(review): presumably optimizer steps; confirm in trainer).
    steps: int = 10_000

    # Gradient-clipping threshold; presumably a max grad-norm, None = no clipping (confirm in trainer).
    grad_clip: float | None = None

    # Training precision; how it is applied (autocast vs. parameter dtype) is not visible here.
    dtype: Literal["bfloat16", "float32"] = "float32"

    # Presumably toggles torch.compile for the model (confirm in trainer).
    compile: bool = False

    gc_freq: int = 50
    """Garbage collection frequency (in steps). If -1, no periodic GC is performed."""

    local_batch_size: int = 32
    """local (per device) batch size"""

    global_batch_size: int = 256
    """
    global (across all devices) batch size, used to compute
    grad_accum_steps = global_batch_size // (local_batch_size * num_devices)
    """

    use_fa4: bool = True
    """Whether to use FA4 flash attention if available."""

    val_freq: int = -1
    """Validation frequency (in steps). If -1, no validation is performed."""

    use_sac: bool = False
    """Whether to use selective activation checkpointing."""
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclasses.dataclass
class Metrics:
    # Logging settings; the code that consumes them (sarasa/metrics.py) is not shown here.

    # Presumably: log metrics every `log_freq` steps (TODO confirm).
    log_freq: int = 10
    # Presumably: also write TensorBoard event files (TODO confirm).
    use_tensorboard: bool = False
    # NOTE(review): likely "log from every node, not only rank 0" (confirm in metrics.py).
    all_node: bool = False
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@dataclasses.dataclass
class Checkpoint:
    # Settings read by sarasa.checkpoint.Checkpointer.

    # A checkpoint is written at steps where `step % save_freq == 0`.
    save_freq: int = 1000
    # "none": blocking dcp.save; "default": dcp.async_save;
    # "mem_pinned": dcp.async_save with a DefaultStager and a process-based checkpointer.
    async_mode: Literal["none", "default", "mem_pinned"] = "default"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclasses.dataclass
class Distributed:
    # Options shared by the distributed strategies (DDP and FSDP subclass this).

    # Process-group backend.
    backend: Literal["nccl", "gloo"] = "nccl"

    init_timeout_seconds: int = 300
    """Timeout for initializing the distributed process group."""

    train_timeout_seconds: int = 100
    """Timeout for distributed training operations after the first iteration."""

    @property
    def name(self) -> str:
        # Lower-cased class name: "ddp" for DDP, "fsdp" for FSDP.
        return self.__class__.__name__.lower()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@dataclasses.dataclass
class DDP(Distributed):
    # Selects the DDP strategy; it adds no options beyond Distributed.
    pass
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclasses.dataclass
class FSDP(Distributed):
    # Selects the FSDP strategy; adds a resharding option on top of Distributed.
    reshard_after_forward: bool = False
    """Whether to reshard model parameters after each forward pass (FSDP only)."""
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@dataclasses.dataclass
|
|
159
|
+
class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
|
|
160
|
+
# variable components
|
|
161
|
+
model: ModelT
|
|
162
|
+
optim: OptimizerT
|
|
163
|
+
lr_scheduler: LRSchedulerT
|
|
164
|
+
data: DataT
|
|
165
|
+
|
|
166
|
+
# static components
|
|
167
|
+
train: Train = dataclasses.field(default_factory=Train)
|
|
168
|
+
metrics: Metrics = dataclasses.field(default_factory=Metrics)
|
|
169
|
+
checkpoint: Checkpoint = dataclasses.field(default_factory=Checkpoint)
|
|
170
|
+
distributed: DDP | FSDP = dataclasses.field(default_factory=DDP)
|
|
171
|
+
|
|
172
|
+
seed: int = 0
|
|
173
|
+
debug: bool = False
|
|
174
|
+
""" Enable debug mode with more verbose logging and checks."""
|
|
175
|
+
|
|
176
|
+
output_dir: Path | str = Path("./outputs")
|
|
177
|
+
"""Directory to save checkpoints and logs."""
|
|
178
|
+
|
|
179
|
+
config_file: Path | str | None = None
|
|
180
|
+
"""Path to a config file (JSON or TOML) to load configuration from."""
|
|
181
|
+
|
|
182
|
+
def __post_init__(self):
|
|
183
|
+
if self.output_dir is not None:
|
|
184
|
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
185
|
+
|
|
186
|
+
if hasattr(self.model, "seq_len") and self.model.seq_len is None:
|
|
187
|
+
if self.data.seq_len is not None:
|
|
188
|
+
self.model.seq_len = self.data.seq_len
|
|
189
|
+
else:
|
|
190
|
+
raise ValueError("Either model.seq_len or data.seq_len must be set.")
|
|
191
|
+
|
|
192
|
+
@classmethod
|
|
193
|
+
def create(
|
|
194
|
+
cls,
|
|
195
|
+
model: ModelT,
|
|
196
|
+
optim: OptimizerT,
|
|
197
|
+
lr_scheduler: LRSchedulerT,
|
|
198
|
+
data: DataT,
|
|
199
|
+
**kwargs,
|
|
200
|
+
) -> Config:
|
|
201
|
+
return cls(
|
|
202
|
+
model=model,
|
|
203
|
+
optim=optim,
|
|
204
|
+
lr_scheduler=lr_scheduler,
|
|
205
|
+
data=data,
|
|
206
|
+
**kwargs,
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
@classmethod
|
|
210
|
+
def from_cli(
|
|
211
|
+
cls,
|
|
212
|
+
*,
|
|
213
|
+
model_type: ModelT = Model,
|
|
214
|
+
optim_type: OptimizerT = AdamW,
|
|
215
|
+
lr_scheduler_type: LRSchedulerT = LRScheduler,
|
|
216
|
+
data_type: DataT = Data,
|
|
217
|
+
) -> Config:
|
|
218
|
+
"""
|
|
219
|
+
initialize JobConfig from command line arguments
|
|
220
|
+
update the values with the following priority: CLI arguments > config file > defaults
|
|
221
|
+
|
|
222
|
+
*_type can be used to specify custom dataclass types for each section
|
|
223
|
+
>> config = Config.from_cli(optim_type=CustomOptimizerConfig)
|
|
224
|
+
"""
|
|
225
|
+
|
|
226
|
+
import importlib.util
|
|
227
|
+
|
|
228
|
+
import tyro
|
|
229
|
+
|
|
230
|
+
loaded_config = None
|
|
231
|
+
|
|
232
|
+
if (under := ("--config_file" in sys.argv)) or ("--config-file" in sys.argv):
|
|
233
|
+
config_file = sys.argv[sys.argv.index("--config_file" if under else "--config-file") + 1]
|
|
234
|
+
config_file = Path(config_file)
|
|
235
|
+
|
|
236
|
+
if not config_file.exists():
|
|
237
|
+
raise FileNotFoundError(f"Config file {config_file} does not exist.")
|
|
238
|
+
|
|
239
|
+
if config_file.suffix != ".py":
|
|
240
|
+
raise ValueError("Only Python config files are supported in this method.")
|
|
241
|
+
|
|
242
|
+
spec = importlib.util.spec_from_file_location("custom_config", config_file)
|
|
243
|
+
module = importlib.util.module_from_spec(spec)
|
|
244
|
+
spec.loader.exec_module(module)
|
|
245
|
+
configs = [
|
|
246
|
+
config
|
|
247
|
+
for config in module.__dict__.values()
|
|
248
|
+
if isinstance(config, cls) and not isinstance(config, type)
|
|
249
|
+
]
|
|
250
|
+
if len(configs) == 0:
|
|
251
|
+
raise ValueError(f"No Config instance found in {config_file}.")
|
|
252
|
+
elif len(configs) > 1:
|
|
253
|
+
raise ValueError(f"Multiple Config instances found in {config_file}. Please keep only one.")
|
|
254
|
+
else:
|
|
255
|
+
loaded_config = configs[0]
|
|
256
|
+
|
|
257
|
+
return tyro.cli(
|
|
258
|
+
cls[
|
|
259
|
+
model_type,
|
|
260
|
+
optim_type,
|
|
261
|
+
lr_scheduler_type,
|
|
262
|
+
data_type,
|
|
263
|
+
],
|
|
264
|
+
default=loaded_config,
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# Public API of sarasa.config (listed alphabetically).
__all__ = [
    "AdamW",
    "Checkpoint",
    "Config",
    "DDP",
    "Data",
    "FSDP",
    "LRScheduler",
    "Metrics",
    "Model",
    "Train",
]
|
sarasa/data/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from torch.utils.data import DataLoader
|
|
6
|
+
|
|
7
|
+
from sarasa.data.hf_datasets import Datasets, HFTextDataset
|
|
8
|
+
from sarasa.data.tokenizer import HFTokenizerWrapper
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclasses.dataclass
class DataConfig:
    """Dataset, tokenizer and data-loader settings."""

    dataset: Datasets = Datasets.fineweb_edu_100b
    """Dataset to use for training. Can be a predefined dataset or a custom dataset path."""

    tokenizer_path: Path | str = Path("./tokenizer")
    """Path to `tokenizer.json` and `tokenizer_config.json` files."""

    seq_len: int = 2048

    num_workers: int = 4
    pin_memory: bool = True

    def create(
        self,
        batch_size: int,
    ) -> dict[str, Any]:
        """Build the tokenizer and the training data loader.

        Returns:
            ``{"tokenizer": HFTokenizerWrapper, "train_loader": DataLoader}``. The
            loader streams the ``"train"`` split in ``batch_size`` batches.
        """
        tokenizer = HFTokenizerWrapper(Path(self.tokenizer_path))
        train_dataset = HFTextDataset(self.dataset, "train", tokenizer, self.seq_len)
        train_loader = DataLoader(
            train_dataset,
            batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        return dict(tokenizer=tokenizer, train_loader=train_loader)
|
|
sarasa/data/hf_datasets.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
from typing import Any, Callable
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from datasets import disable_progress_bars, load_dataset
|
|
6
|
+
from datasets.distributed import split_dataset_by_node
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from torch.utils.data import IterableDataset
|
|
9
|
+
|
|
10
|
+
from sarasa.utils import rank, world_size
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Datasets(enum.StrEnum):
|
|
14
|
+
c4 = enum.auto()
|
|
15
|
+
fineweb_edu = enum.auto()
|
|
16
|
+
fineweb_edu_100b = enum.auto()
|
|
17
|
+
fineweb_edu_dedup = enum.auto()
|
|
18
|
+
|
|
19
|
+
def load(
|
|
20
|
+
self,
|
|
21
|
+
cache_dir: str | None,
|
|
22
|
+
) -> Any:
|
|
23
|
+
match self:
|
|
24
|
+
case Datasets.c4:
|
|
25
|
+
return load_dataset(
|
|
26
|
+
"allenai/c4",
|
|
27
|
+
name="en",
|
|
28
|
+
split="train",
|
|
29
|
+
streaming=True,
|
|
30
|
+
cache_dir=cache_dir,
|
|
31
|
+
)
|
|
32
|
+
case Datasets.fineweb_edu:
|
|
33
|
+
return load_dataset(
|
|
34
|
+
"HuggingFaceFW/fineweb-edu",
|
|
35
|
+
name="default",
|
|
36
|
+
split="train",
|
|
37
|
+
streaming=True,
|
|
38
|
+
cache_dir=cache_dir,
|
|
39
|
+
)
|
|
40
|
+
case Datasets.fineweb_edu_100b:
|
|
41
|
+
return load_dataset(
|
|
42
|
+
"HuggingFaceFW/fineweb-edu",
|
|
43
|
+
name="sample-100BT",
|
|
44
|
+
split="train",
|
|
45
|
+
streaming=True,
|
|
46
|
+
cache_dir=cache_dir,
|
|
47
|
+
)
|
|
48
|
+
case Datasets.fineweb_edu_dedup:
|
|
49
|
+
return load_dataset(
|
|
50
|
+
"HuggingFaceTB/smollm-corpus",
|
|
51
|
+
"fineweb-edu-dedup",
|
|
52
|
+
split="train",
|
|
53
|
+
streaming=True,
|
|
54
|
+
cache_dir=cache_dir,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class HFTextDataset(IterableDataset):
    """Stream a Hugging Face text dataset as fixed-length next-token-prediction samples.

    Documents are tokenized and concatenated into a running token buffer. Each
    ``seq_len + 1`` tokens yield one ``({"input": x[:-1]}, x[1:])`` pair of
    ``torch.long`` tensors. The stream is sharded across ranks with
    ``split_dataset_by_node``.
    """

    def __init__(
        self,
        dataset_name: Datasets | str,
        split: str,
        tokenizer: Any,  # any object exposing encode(str) -> list[int]
        seq_len: int,
        infinite: bool = True,
        cache_dir: str | None = None,
    ):
        """
        Args:
            dataset_name: a ``Datasets`` member/value, or any name accepted by
                ``datasets.load_dataset``.
            split: split to load. Predefined ``Datasets`` always stream their
                ``train`` split.
            tokenizer: object with ``encode(text) -> list[int]``.
            seq_len: number of input tokens per sample.
            infinite: when True, re-loop over the data after it is exhausted;
                otherwise stop.
            cache_dir: forwarded to the dataset loader.
        """
        if rank() != 0:
            disable_progress_bars()
        self.dataset_name = dataset_name
        if dataset_name in Datasets:
            if split != "train":
                logger.warning(f"Predefined dataset {dataset_name} only streams the 'train' split; ignoring {split=}.")
            ds = Datasets(dataset_name).load(cache_dir=cache_dir)
        else:
            logger.warning(f"Unknown dataset: {dataset_name}. Trying to use `load_dataset` directly.")
            ds = load_dataset(dataset_name, split=split, streaming=True, cache_dir=cache_dir)

        self.data = split_dataset_by_node(ds, rank=rank(), world_size=world_size())
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # Previously accepted but never stored, which made exhaustion raise AttributeError.
        self.infinite = infinite
        self.token_buffer: list[int] = []

    def _text_processor(
        self,
        sample: dict,
    ) -> str:
        # Default text processor: extract 'text' field
        return sample["text"]

    def __iter__(self):
        # one extra token so that inputs and labels can be offset by one position
        chunk_len = self.seq_len + 1

        while True:
            for sample in self.data:
                sample_tokens = self.tokenizer.encode(self._text_processor(sample))
                self.token_buffer.extend(sample_tokens)

                while len(self.token_buffer) >= chunk_len:
                    x = torch.tensor(self.token_buffer[:chunk_len], dtype=torch.long)
                    # drop the consumed tokens before yielding
                    self.token_buffer = self.token_buffer[chunk_len:]
                    yield {"input": x[:-1]}, x[1:]

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break

            # Start another pass over the data, bumping the epoch when supported.
            logger.warning(f"Dataset {self.dataset_name} is being re-looped")
            if hasattr(self.data, "set_epoch") and hasattr(self.data, "epoch"):
                self.data.set_epoch(self.data.epoch + 1)
|
sarasa/data/tokenizer.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from tokenizers import Tokenizer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseTokenizerWrapper:
    """Tokenizer interface: subclasses implement ``encode``, ``decode`` and ``__len__``.

    Every method here raises ``NotImplementedError``.
    """

    def encode(self, *args, **kwargs) -> list[int]:
        """Turn text into token ids."""
        raise NotImplementedError()

    def decode(self, *args, **kwargs) -> str:
        """Turn token ids back into text."""
        raise NotImplementedError()

    def __len__(self) -> int:
        """Size of the vocabulary."""
        raise NotImplementedError()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class HFTokenizerWrapper(BaseTokenizerWrapper):
    """Wrap a Hugging Face ``tokenizers.Tokenizer`` loaded from a local directory.

    The directory must contain ``tokenizer.json`` and ``tokenizer_config.json``. The
    latter must define ``bos_token``. ``encode`` prepends the BOS id unless the
    tokenizer already inserts it on its own.
    """

    def __init__(
        self,
        tokenizer_path: Path | str,
    ):
        """
        Raises:
            ValueError: the config defines no BOS token, or the BOS token is not in
                the tokenizer vocabulary.
        """
        tokenizer_path = Path(tokenizer_path)
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path / "tokenizer.json"))
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with (tokenizer_path / "tokenizer_config.json").open("r", encoding="utf-8") as f:
            config = json.load(f)

        bos_token = self._get_tokens_from_config(config.get("bos_token", None))
        if bos_token is None:
            raise ValueError("BOS token must be specified in the tokenizer config.")

        self.bos_token_id = self.tokenizer.token_to_id(bos_token)
        if self.bos_token_id is None:
            # Otherwise encode() would prepend None to every sequence.
            raise ValueError(f"BOS token {bos_token!r} is not in the tokenizer vocabulary.")

        # check if tokenizer adds bos token automatically
        test_encoding = self.tokenizer.encode("test").ids
        self.need_bos = self.bos_token_id not in test_encoding

    def _get_tokens_from_config(
        self,
        token: dict[str, str] | str | None,
    ) -> str | None:
        """Return the token string, unwrapping ``{"content": ...}`` entries."""
        if isinstance(token, dict):
            token = token["content"]
        return token

    def encode(
        self,
        text: str,
    ) -> list[int]:
        """Tokenize ``text``, prepending the BOS id if the tokenizer does not add it."""
        token_ids = self.tokenizer.encode(text).ids

        if self.need_bos:
            token_ids = [self.bos_token_id] + token_ids

        return token_ids

    def decode(
        self,
        token_ids: list[int],
        **kwargs,
    ) -> str:
        """Decode ``token_ids``; extra kwargs go to ``Tokenizer.decode``."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def __len__(self) -> int:
        """Vocabulary size reported by the underlying tokenizer."""
        return self.tokenizer.get_vocab_size()
|