sarasa 0.0.2__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sarasa-0.0.2 → sarasa-0.0.3}/PKG-INFO +1 -1
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/config.py +21 -4
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/trainer.py +39 -25
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/utils.py +2 -2
- {sarasa-0.0.2 → sarasa-0.0.3}/tests/test_config.py +7 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/.github/workflows/pypi.yaml +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/.github/workflows/tests_and_lint.yaml +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/.gitignore +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/LICENSE +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/README.md +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/configs/example.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/configs/llama3-1b.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/main.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/pyproject.toml +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/__init__.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/activation_checkpoint.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/checkpoint.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/data/__init__.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/data/hf_datasets.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/data/tokenizer.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/metrics.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/__init__.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/attention.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/llama3.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/nanochat_gpt.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/utils.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/optimizers/__init__.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/optimizers/utils.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/tests/test_model.py +0 -0
- {sarasa-0.0.2 → sarasa-0.0.3}/tests/test_utils.py +0 -0

{sarasa-0.0.2 → sarasa-0.0.3}/sarasa/config.py +21 -4

@@ -91,6 +91,10 @@ class Train:
     grad_clip: float | None = None

     dtype: Literal["bfloat16", "float32"] = "float32"
+    """Dtype used for model initialization"""
+
+    amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+    """Dtype used for automatic mixed precision training"""

     compile: bool = False

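The two new `Train` fields are plain `Literal` strings rather than `torch.dtype` objects, which keeps them easy to set from a config file or the command line. A minimal standalone sketch of how tyro (the CLI library the package uses) exposes such a field as a constrained option; the `Train` class below is a reduced stand-in, not sarasa's actual dataclass:

```python
# Reduced stand-in for sarasa's Train config: Literal fields become CLI choices.
import dataclasses
from typing import Literal

import tyro


@dataclasses.dataclass
class Train:
    dtype: Literal["bfloat16", "float32"] = "float32"
    amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"


cfg = tyro.cli(Train, args=["--dtype", "bfloat16"])
print(cfg)  # Train(dtype='bfloat16', amp_dtype='bfloat16')
```

`amp_dtype` is exposed the same way; the exact flag spelling (`--amp-dtype` vs `--amp_dtype`) depends on tyro's underscore settings.
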
@@ -154,6 +158,12 @@ class FSDP(Distributed):
     reshard_after_forward: bool = False
     """Whether to reshard model parameters after each forward pass (FSDP only)."""

+    dtype: str | None = None
+    """Dtype for FSDP reduce operations. If None, uses train.dtype."""
+
+    amp_dtype: str | None = None
+    """Dtype for FSDP parameter storage. If None, uses train.amp_dtype."""
+

 @dataclasses.dataclass
 class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:

@@ -183,11 +193,15 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
         if self.output_dir is not None:
             self.output_dir.mkdir(parents=True, exist_ok=True)

-        if hasattr(self.model, "seq_len")
-            if self.data.seq_len is not None:
+        if hasattr(self.model, "seq_len"):
+            if self.model.seq_len is None and self.data.seq_len is not None:
                 self.model.seq_len = self.data.seq_len
-
-                raise ValueError("
+            if self.model.seq_len is None:
+                raise ValueError("seq_len must be specified in either model or data configuration.")
+
+        if isinstance(self.distributed, FSDP):
+            self.distributed.dtype = self.distributed.dtype or self.train.dtype
+            self.distributed.amp_dtype = self.distributed.amp_dtype or self.train.amp_dtype

     @classmethod
     def create(

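The `__post_init__` additions do two things: resolve `model.seq_len` from `data.seq_len` (failing early if neither is set), and fill the new FSDP dtype fields from the `Train` dtypes whenever they are left as `None`. A standalone sketch of that fallback, using reduced stand-in dataclasses rather than sarasa's real `Train`/`FSDP` classes:

```python
# Standalone sketch of the dtype fallback added in Config.__post_init__:
# an unset FSDP dtype falls back to the corresponding Train dtype.
import dataclasses


@dataclasses.dataclass
class TrainCfg:
    dtype: str = "float32"
    amp_dtype: str = "bfloat16"


@dataclasses.dataclass
class FSDPCfg:
    dtype: str | None = None
    amp_dtype: str | None = None


train, fsdp = TrainCfg(), FSDPCfg()
fsdp.dtype = fsdp.dtype or train.dtype              # -> "float32"
fsdp.amp_dtype = fsdp.amp_dtype or train.amp_dtype  # -> "bfloat16"
print(fsdp)  # FSDPCfg(dtype='float32', amp_dtype='bfloat16')
```
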
@@ -227,6 +241,8 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:

         import tyro

+        from sarasa.utils import rank
+
         loaded_config = None

         if (under := ("--config_file" in sys.argv)) or ("--config-file" in sys.argv):

@@ -262,6 +278,7 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
                 data_type,
             ],
             default=loaded_config,
+            console_outputs=(rank() == 0),
         )

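`console_outputs=(rank() == 0)` silences tyro's help, usage, and error output on every rank except rank 0, so multi-process launches do not print the same text once per worker. A sketch of the pattern; the `rank()` helper here is a stand-in that reads the torchrun-style `RANK` variable, which is an assumption about what `sarasa.utils.rank` actually does:

```python
# Rank-gated CLI parsing: only rank 0 lets tyro write to the console.
import dataclasses
import os

import tyro


def rank() -> int:
    # torchrun-style environment variable; assumption for this sketch
    return int(os.environ.get("RANK", "0"))


@dataclasses.dataclass
class Args:
    steps: int = 10


args = tyro.cli(Args, args=[], console_outputs=(rank() == 0))
print(args)
```
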
{sarasa-0.0.2 → sarasa-0.0.3}/sarasa/trainer.py +39 -25

@@ -102,7 +102,10 @@ class Trainer:

         self.amp_context = contextlib.nullcontext()
         if config.distributed.name != "fsdp":
-            self.amp_context = torch.autocast(
+            self.amp_context = torch.autocast(
+                device_type=self.device.type,
+                dtype=getattr(torch, config.train.amp_dtype),
+            )

         # todo: setup profiler context
         self.profile_context = contextlib.nullcontext()

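On the non-FSDP path, the trainer now builds its autocast context from the config strings via `getattr(torch, ...)`, so parameters live in `train.dtype` while autocast compute runs in `train.amp_dtype`. A standalone sketch of that behaviour, assuming a recent PyTorch with CPU bfloat16 autocast support (not sarasa's actual trainer code):

```python
# Parameters are created in Train.dtype; compute inside autocast runs in Train.amp_dtype.
import torch

dtype, amp_dtype = "float32", "bfloat16"  # Train defaults from the diff

model = torch.nn.Linear(8, 8).to(getattr(torch, dtype))
with torch.autocast(device_type="cpu", dtype=getattr(torch, amp_dtype)):
    out = model(torch.randn(2, 8))

print(model.weight.dtype, out.dtype)  # torch.float32 torch.bfloat16
```
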
@@ -126,30 +129,34 @@ class Trainer:

     @record
     def train(self):
-        … (the previous 24-line train() body; its content is not preserved in this diff view)
+        try:
+            logger.info("Starting training...")
+
+            self.model.train()
+            with self.profile_context:
+                data_iter = self.batch_generator(self.data_loader)
+                for _ in range(self.config.train.steps):
+                    self.step += 1
+                    self.gc.collect(self.step)
+                    try:
+                        self.train_step(data_iter)
+                    except StopIteration:
+                        logger.warning("Data loader exhausted during training.")
+                        break
+
+                    if self.checkpointer is not None:
+                        self.checkpointer.save(self.step)
+
+                    if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
+                        self.evaluate()
+
+                    if world_size() > 1 and self.step == 1:
+                        update_timeout(self.config.distributed.train_timeout_seconds, self.device)
+
+            logger.info("Training completed.")
+        finally:
+            logger.info("Cleaning up trainer...")
+            self.close()

     def batch_generator(
         self,

@@ -242,3 +249,10 @@ class Trainer:
         batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
     ) -> None:
         raise NotImplementedError
+
+    def close(self) -> None:
+        if self.checkpointer is not None:
+            self.checkpointer.close()
+
+        if self.metrics_processor is not None:
+            self.metrics_processor.close()

{sarasa-0.0.2 → sarasa-0.0.3}/sarasa/utils.py +2 -2

@@ -149,8 +149,8 @@ def apply_distributed(

     # todo: make dtypes configurable
     mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.
-        reduce_dtype=torch.
+        param_dtype=getattr(torch, config.amp_dtype),
+        reduce_dtype=getattr(torch, config.dtype),
     )

     for block in model.blocks:

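`apply_distributed` now feeds the resolved config strings into FSDP2's `MixedPrecisionPolicy`: `amp_dtype` becomes the parameter/compute dtype and `dtype` the gradient-reduce dtype. A sketch of just that mapping; the import path assumes a recent PyTorch (older releases exposed the class under `torch.distributed._composable.fsdp`), and the per-block `fully_shard` wiring around it is assumed from context rather than shown in this hunk:

```python
# FSDP.amp_dtype -> parameter/compute dtype, FSDP.dtype -> gradient-reduce dtype.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy  # FSDP2 API

amp_dtype = "bfloat16"  # resolved from FSDP.amp_dtype (or train.amp_dtype)
dtype = "float32"       # resolved from FSDP.dtype (or train.dtype)

mp_policy = MixedPrecisionPolicy(
    param_dtype=getattr(torch, amp_dtype),  # torch.bfloat16
    reduce_dtype=getattr(torch, dtype),     # torch.float32
)
print(mp_policy)
```
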
{sarasa-0.0.2 → sarasa-0.0.3}/tests/test_config.py +7 -0

@@ -73,3 +73,10 @@ def test_config_loading_filetype_error(monkeypatch, tmp_path):
    monkeypatch.setattr(sys, "argv", ["program", "--config_file", str(config_file)])
    with pytest.raises(ValueError):
        Config.from_cli()
+
+
+def test_config_post_init(monkeypatch):
+    monkeypatch.setattr(sys, "argv", ["program", "distributed:fsdp"])
+    config = Config.from_cli()  # just check no error is raised
+    assert config.distributed.dtype == config.train.dtype
+    assert config.distributed.amp_dtype == config.train.amp_dtype
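
The new test exercises the resolution end to end: `distributed:fsdp` appears to be tyro's subcommand syntax for selecting the `FSDP` variant of the union-typed `distributed` field, after which `__post_init__` is expected to copy `train.dtype` and `train.amp_dtype` into the FSDP config. It can be run in isolation with `pytest tests/test_config.py::test_config_post_init`.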