sarasa 0.0.2__tar.gz → 0.0.3__tar.gz

This diff compares the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (30)
  1. {sarasa-0.0.2 → sarasa-0.0.3}/PKG-INFO +1 -1
  2. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/config.py +21 -4
  3. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/trainer.py +39 -25
  4. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/utils.py +2 -2
  5. {sarasa-0.0.2 → sarasa-0.0.3}/tests/test_config.py +7 -0
  6. {sarasa-0.0.2 → sarasa-0.0.3}/.github/workflows/pypi.yaml +0 -0
  7. {sarasa-0.0.2 → sarasa-0.0.3}/.github/workflows/tests_and_lint.yaml +0 -0
  8. {sarasa-0.0.2 → sarasa-0.0.3}/.gitignore +0 -0
  9. {sarasa-0.0.2 → sarasa-0.0.3}/LICENSE +0 -0
  10. {sarasa-0.0.2 → sarasa-0.0.3}/README.md +0 -0
  11. {sarasa-0.0.2 → sarasa-0.0.3}/configs/example.py +0 -0
  12. {sarasa-0.0.2 → sarasa-0.0.3}/configs/llama3-1b.py +0 -0
  13. {sarasa-0.0.2 → sarasa-0.0.3}/main.py +0 -0
  14. {sarasa-0.0.2 → sarasa-0.0.3}/pyproject.toml +0 -0
  15. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/__init__.py +0 -0
  16. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/activation_checkpoint.py +0 -0
  17. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/checkpoint.py +0 -0
  18. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/data/__init__.py +0 -0
  19. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/data/hf_datasets.py +0 -0
  20. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/data/tokenizer.py +0 -0
  21. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/metrics.py +0 -0
  22. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/__init__.py +0 -0
  23. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/attention.py +0 -0
  24. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/llama3.py +0 -0
  25. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/nanochat_gpt.py +0 -0
  26. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/models/utils.py +0 -0
  27. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/optimizers/__init__.py +0 -0
  28. {sarasa-0.0.2 → sarasa-0.0.3}/sarasa/optimizers/utils.py +0 -0
  29. {sarasa-0.0.2 → sarasa-0.0.3}/tests/test_model.py +0 -0
  30. {sarasa-0.0.2 → sarasa-0.0.3}/tests/test_utils.py +0 -0
{sarasa-0.0.2 → sarasa-0.0.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sarasa
-Version: 0.0.2
+Version: 0.0.3
 Summary: Add your description here
 License-File: LICENSE
 Requires-Python: >=3.13
{sarasa-0.0.2 → sarasa-0.0.3}/sarasa/config.py

@@ -91,6 +91,10 @@ class Train:
     grad_clip: float | None = None
 
     dtype: Literal["bfloat16", "float32"] = "float32"
+    """Dtype used for model initialization"""
+
+    amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+    """Dtype used for automatic mixed precision training"""
 
     compile: bool = False
 
@@ -154,6 +158,12 @@ class FSDP(Distributed):
     reshard_after_forward: bool = False
     """Whether to reshard model parameters after each forward pass (FSDP only)."""
 
+    dtype: str | None = None
+    """Dtype for FSDP reduce operations. If None, uses train.dtype."""
+
+    amp_dtype: str | None = None
+    """Dtype for FSDP parameter storage. If None, uses train.amp_dtype."""
+
 
 @dataclasses.dataclass
 class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
@@ -183,11 +193,15 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
         if self.output_dir is not None:
             self.output_dir.mkdir(parents=True, exist_ok=True)
 
-        if hasattr(self.model, "seq_len") and self.model.seq_len is None:
-            if self.data.seq_len is not None:
+        if hasattr(self.model, "seq_len"):
+            if self.model.seq_len is None and self.data.seq_len is not None:
                 self.model.seq_len = self.data.seq_len
-            else:
-                raise ValueError("Either model.seq_len or data.seq_len must be set.")
+            if self.model.seq_len is None:
+                raise ValueError("seq_len must be specified in either model or data configuration.")
+
+        if isinstance(self.distributed, FSDP):
+            self.distributed.dtype = self.distributed.dtype or self.train.dtype
+            self.distributed.amp_dtype = self.distributed.amp_dtype or self.train.amp_dtype
 
     @classmethod
     def create(
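The new fallback in __post_init__ means dtypes only need to be set in one place: when distributed.dtype or distributed.amp_dtype is left unset for an FSDP run, it inherits train.dtype or train.amp_dtype. A minimal standalone sketch of that resolution (field names mirror the hunks above; these are not the package's real classes):

import dataclasses


@dataclasses.dataclass
class Train:
    dtype: str = "float32"        # dtype used for model initialization
    amp_dtype: str = "bfloat16"   # dtype used for automatic mixed precision


@dataclasses.dataclass
class FSDP:
    dtype: str | None = None      # reduce dtype; falls back to Train.dtype
    amp_dtype: str | None = None  # param dtype; falls back to Train.amp_dtype


train, fsdp = Train(), FSDP()
# the same "value or fallback" resolution performed in Config.__post_init__
fsdp.dtype = fsdp.dtype or train.dtype              # -> "float32"
fsdp.amp_dtype = fsdp.amp_dtype or train.amp_dtype  # -> "bfloat16"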
@@ -227,6 +241,8 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
 
         import tyro
 
+        from sarasa.utils import rank
+
         loaded_config = None
 
         if (under := ("--config_file" in sys.argv)) or ("--config-file" in sys.argv):
@@ -262,6 +278,7 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
                 data_type,
             ],
             default=loaded_config,
+            console_outputs=(rank() == 0),
         )
 
 
{sarasa-0.0.2 → sarasa-0.0.3}/sarasa/trainer.py

@@ -102,7 +102,10 @@ class Trainer:
 
         self.amp_context = contextlib.nullcontext()
         if config.distributed.name != "fsdp":
-            self.amp_context = torch.autocast(device_type=self.device.type, dtype=getattr(torch, config.train.dtype))
+            self.amp_context = torch.autocast(
+                device_type=self.device.type,
+                dtype=getattr(torch, config.train.amp_dtype),
+            )
 
         # todo: setup profiler context
         self.profile_context = contextlib.nullcontext()
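The non-FSDP autocast context now takes its dtype from train.amp_dtype rather than train.dtype, so initialization precision and mixed-precision compute can differ. A hedged sketch of how such a context is typically used around a forward/backward pass — generic PyTorch, not sarasa's actual train_step:

import torch

amp_dtype = "bfloat16"  # stands in for config.train.amp_dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp_context = torch.autocast(device_type=device.type, dtype=getattr(torch, amp_dtype))

model = torch.nn.Linear(8, 8).to(device)
x = torch.randn(4, 8, device=device)
with amp_context:
    # matmul-heavy ops run in amp_dtype; the loss is upcast before the reduction
    loss = model(x).float().pow(2).mean()
loss.backward()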
@@ -126,30 +129,34 @@ class Trainer:
 
     @record
     def train(self):
-        logger.info("Starting training...")
-
-        self.model.train()
-        with self.profile_context:
-            data_iter = self.batch_generator(self.data_loader)
-            for _ in range(self.config.train.steps):
-                self.step += 1
-                self.gc.collect(self.step)
-                try:
-                    self.train_step(data_iter)
-                except StopIteration:
-                    logger.warning("Data loader exhausted during training.")
-                    break
-
-                if self.checkpointer is not None:
-                    self.checkpointer.save(self.step)
-
-                if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
-                    self.evaluate()
-
-                if world_size() > 1 and self.step == 1:
-                    update_timeout(self.config.distributed.train_timeout_seconds, self.device)
-
-        logger.info("Training completed.")
+        try:
+            logger.info("Starting training...")
+
+            self.model.train()
+            with self.profile_context:
+                data_iter = self.batch_generator(self.data_loader)
+                for _ in range(self.config.train.steps):
+                    self.step += 1
+                    self.gc.collect(self.step)
+                    try:
+                        self.train_step(data_iter)
+                    except StopIteration:
+                        logger.warning("Data loader exhausted during training.")
+                        break
+
+                    if self.checkpointer is not None:
+                        self.checkpointer.save(self.step)
+
+                    if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
+                        self.evaluate()
+
+                    if world_size() > 1 and self.step == 1:
+                        update_timeout(self.config.distributed.train_timeout_seconds, self.device)
+
+            logger.info("Training completed.")
+        finally:
+            logger.info("Cleaning up trainer...")
+            self.close()
 
     def batch_generator(
         self,
@@ -242,3 +249,10 @@ class Trainer:
         batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
     ) -> None:
         raise NotImplementedError
+
+    def close(self) -> None:
+        if self.checkpointer is not None:
+            self.checkpointer.close()
+
+        if self.metrics_processor is not None:
+            self.metrics_processor.close()
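Together with the try/finally added to train(), the new close() method guarantees the checkpointer and metrics processor are shut down even when a training step raises. The guarantee is ordinary try/finally semantics; a trivial standalone illustration (not sarasa code):

def run(steps: int) -> None:
    try:
        for step in range(steps):
            if step == 2:
                raise RuntimeError("boom")
    finally:
        print("cleaning up")  # runs on normal completion and on failure


try:
    run(5)
except RuntimeError:
    pass  # "cleaning up" was printed before the exception propagated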
{sarasa-0.0.2 → sarasa-0.0.3}/sarasa/utils.py

@@ -149,8 +149,8 @@ def apply_distributed(
 
     # todo: make dtypes configurable
     mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.bfloat16,
-        reduce_dtype=torch.float32,
+        param_dtype=getattr(torch, config.amp_dtype),
+        reduce_dtype=getattr(torch, config.dtype),
     )
 
     for block in model.blocks:
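The previously hard-coded FSDP mixed-precision policy is now driven by the distributed config: amp_dtype sets the sharded parameter/compute dtype and dtype sets the gradient-reduction dtype. A hedged sketch of constructing the same kind of policy (the import path assumes a recent PyTorch that exposes the FSDP2 fully_shard API; building the policy object itself needs no process group):

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy  # exported in recent PyTorch releases

amp_dtype = "bfloat16"  # stands in for config.amp_dtype (parameter/compute dtype)
dtype = "float32"       # stands in for config.dtype (gradient-reduction dtype)

mp_policy = MixedPrecisionPolicy(
    param_dtype=getattr(torch, amp_dtype),
    reduce_dtype=getattr(torch, dtype),
)
# With a device mesh initialized, each transformer block would then be sharded
# with something like: fully_shard(block, mp_policy=mp_policy)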
{sarasa-0.0.2 → sarasa-0.0.3}/tests/test_config.py

@@ -73,3 +73,10 @@ def test_config_loading_filetype_error(monkeypatch, tmp_path):
     monkeypatch.setattr(sys, "argv", ["program", "--config_file", str(config_file)])
     with pytest.raises(ValueError):
         Config.from_cli()
+
+
+def test_config_post_init(monkeypatch):
+    monkeypatch.setattr(sys, "argv", ["program", "distributed:fsdp"])
+    config = Config.from_cli()  # just check no error is raised
+    assert config.distributed.dtype == config.train.dtype
+    assert config.distributed.amp_dtype == config.train.amp_dtype