sarasa 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {sarasa-0.0.2 → sarasa-0.0.4}/PKG-INFO +10 -4
  2. {sarasa-0.0.2 → sarasa-0.0.4}/README.md +9 -3
  3. {sarasa-0.0.2 → sarasa-0.0.4}/configs/example.py +5 -2
  4. {sarasa-0.0.2 → sarasa-0.0.4}/configs/llama3-1b.py +4 -3
  5. sarasa-0.0.4/sarasa/__init__.py +11 -0
  6. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/checkpoint.py +10 -0
  7. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/config.py +21 -4
  8. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/__init__.py +2 -0
  9. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/attention.py +6 -2
  10. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/llama3.py +18 -7
  11. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/nanochat_gpt.py +10 -1
  12. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/models/utils.py +0 -9
  13. sarasa-0.0.2/sarasa/trainer.py → sarasa-0.0.4/sarasa/train.py +65 -39
  14. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/utils.py +8 -4
  15. {sarasa-0.0.2 → sarasa-0.0.4}/tests/test_config.py +7 -0
  16. sarasa-0.0.2/sarasa/__init__.py +0 -2
  17. {sarasa-0.0.2 → sarasa-0.0.4}/.github/workflows/pypi.yaml +0 -0
  18. {sarasa-0.0.2 → sarasa-0.0.4}/.github/workflows/tests_and_lint.yaml +0 -0
  19. {sarasa-0.0.2 → sarasa-0.0.4}/.gitignore +0 -0
  20. {sarasa-0.0.2 → sarasa-0.0.4}/LICENSE +0 -0
  21. {sarasa-0.0.2 → sarasa-0.0.4}/main.py +0 -0
  22. {sarasa-0.0.2 → sarasa-0.0.4}/pyproject.toml +0 -0
  23. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/activation_checkpoint.py +0 -0
  24. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/data/__init__.py +0 -0
  25. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/data/hf_datasets.py +0 -0
  26. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/data/tokenizer.py +0 -0
  27. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/metrics.py +0 -0
  28. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/optimizers/__init__.py +0 -0
  29. {sarasa-0.0.2 → sarasa-0.0.4}/sarasa/optimizers/utils.py +0 -0
  30. {sarasa-0.0.2 → sarasa-0.0.4}/tests/test_model.py +0 -0
  31. {sarasa-0.0.2 → sarasa-0.0.4}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sarasa
- Version: 0.0.2
+ Version: 0.0.4
  Summary: Add your description here
  License-File: LICENSE
  Requires-Python: >=3.13
@@ -46,6 +46,8 @@ uv add sarasa[cpu|cu128|cu130]
  - Async distributed checkpoint saving

  - [ ] Checkpoint loading
+ - [ ] FP8 training
+ - [ ] Profiling

  ## Usage

@@ -100,18 +102,22 @@ if __name__ == "__main__":
  trainer.train()
  ```

+ Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
  From the command line, you can specify which custom optimizer to use:

  ```bash
  python script.py optim:custom_optim --optim.lr 0.001 ...
  ```

+ (As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
  ### Config File Example

- It's very simple. IDE autocompletion will help you.
+ It's very simple.
+ IDE autocompletion will help you.

  ```python
- from sarasa.config import Config, Data, LRScheduler, Model, Train, LRScheduler
+ from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
  from custom_optim import CustomOptim

  # only one Config instance should be defined in each config file
@@ -135,4 +141,4 @@ config = Config.create(

  ## Acknowledgements

- This project is heavily inspired by and borrows code from `torchtitan`.
+ This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
@@ -23,6 +23,8 @@ uv add sarasa[cpu|cu128|cu130]
  - Async distributed checkpoint saving

  - [ ] Checkpoint loading
+ - [ ] FP8 training
+ - [ ] Profiling

  ## Usage

@@ -77,18 +79,22 @@ if __name__ == "__main__":
  trainer.train()
  ```

+ Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
  From the command line, you can specify which custom optimizer to use:

  ```bash
  python script.py optim:custom_optim --optim.lr 0.001 ...
  ```

+ (As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
  ### Config File Example

- It's very simple. IDE autocompletion will help you.
+ It's very simple.
+ IDE autocompletion will help you.

  ```python
- from sarasa.config import Config, Data, LRScheduler, Model, Train, LRScheduler
+ from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
  from custom_optim import CustomOptim

  # only one Config instance should be defined in each config file
@@ -112,4 +118,4 @@ config = Config.create(

  ## Acknowledgements

- This project is heavily inspired by and borrows code from `torchtitan`.
+ This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
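The README above leaves the shape of a custom optimizer config implicit. A minimal sketch under stated assumptions: the `lr` field and the CLI form `optim:custom_optim --optim.lr 0.001` come from the README, while the `weight_decay` field and the `create()` factory hook are hypothetical, since this diff does not show the interface sarasa expects from an optimizer config.

```python
# Hypothetical custom optimizer config, sketched from the README above.
# Only the CLI shape `optim:custom_optim --optim.lr 0.001` is documented;
# the fields and the create() hook here are assumptions.
import dataclasses

import torch


@dataclasses.dataclass
class CustomOptim:
    # tyro converts the class name from CamelCase to snake_case, so this
    # surfaces as the `optim:custom_optim` subcommand (hence the README's
    # advice to avoid a `Config` suffix in the class name).
    lr: float = 1e-3
    weight_decay: float = 0.0

    def create(self, params):  # hypothetical factory hook
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.weight_decay)
```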
@@ -1,11 +1,14 @@
  from sarasa.config import AdamW, Config, Data, LRScheduler, Model, Train

  config = Config.create(
-     model=Model(num_layers=12),
+     model=Model(
+         name="nanochat_gpt",
+         num_layers=12,
+         qk_norm=True,
+     ),
      train=Train(
          local_batch_size=16,
          global_batch_size=256,
-         dtype="bfloat16",
      ),
      data=Data(tokenizer_path="./tokenizer"),
      lr_scheduler=LRScheduler(
@@ -2,17 +2,18 @@ from sarasa.config import FSDP, AdamW, Config, Data, LRScheduler, Model, Train

  config = Config.create(
      model=Model(
+         name="llama3",
          hidden_dim=2048,
          num_layers=16,
          num_heads=32,
          num_kv_heads=8,
          head_dim=64,
-         name="llama3",
+         rms_eps=1e-5,
+         rms_learnable=True,
      ),
      train=Train(
          local_batch_size=32,
-         global_batch_size=256,
-         dtype="bfloat16",
+         global_batch_size=1024,
          use_sac=True,
      ),
      data=Data(tokenizer_path="./tokenizer"),
@@ -0,0 +1,11 @@
+ from .config import DDP as DDP
+ from .config import FSDP as FSDP
+ from .config import AdamW as AdamW
+ from .config import Checkpoint as Checkpoint
+ from .config import Config as Config
+ from .config import Data as Data
+ from .config import LRScheduler as LRScheduler
+ from .config import Metrics as Metrics
+ from .config import Model as Model
+ from .config import Train as Train
+ from .train import Trainer as Trainer
@@ -110,3 +110,13 @@ class Checkpointer:
      def close(self) -> None:
          if self.stager is not None:
              self.stager.close()
+
+         if self.save_future is not None:
+             self.save_future.result()
+
+         if self.pg is not None:
+             dist.destroy_process_group(self.pg)
+             self.pg = None
+
+     def __del__(self):
+         self.close()
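The Checkpointer change above waits on any in-flight async save and tears down its process group in `close()`, with `__del__` as a fallback. A standalone sketch of that cleanup pattern, using a generic class rather than sarasa's Checkpointer:

```python
# Standalone sketch of the cleanup pattern above: close() is safe to call more
# than once (the handle is nulled out after teardown), and __del__ only
# delegates to close() as a last-resort safety net.
from concurrent.futures import Future


class AsyncSaver:
    def __init__(self) -> None:
        self.save_future: Future | None = None
        self.resource: object | None = object()  # stands in for a process group

    def close(self) -> None:
        if self.save_future is not None:
            self.save_future.result()  # block until the in-flight save finishes
            self.save_future = None

        if self.resource is not None:
            # real code would call dist.destroy_process_group(self.pg) here
            self.resource = None

    def __del__(self) -> None:
        self.close()
```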
@@ -91,6 +91,10 @@ class Train:
      grad_clip: float | None = None

      dtype: Literal["bfloat16", "float32"] = "float32"
+     """Dtype used for model initialization"""
+
+     amp_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+     """Dtype used for automatic mixed precision training"""

      compile: bool = False

@@ -154,6 +158,12 @@ class FSDP(Distributed):
      reshard_after_forward: bool = False
      """Whether to reshard model parameters after each forward pass (FSDP only)."""

+     dtype: str | None = None
+     """Dtype for FSDP reduce operations. If None, uses train.dtype."""
+
+     amp_dtype: str | None = None
+     """Dtype for FSDP parameter storage. If None, uses train.amp_dtype."""
+

  @dataclasses.dataclass
  class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
@@ -183,11 +193,15 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
          if self.output_dir is not None:
              self.output_dir.mkdir(parents=True, exist_ok=True)

-         if hasattr(self.model, "seq_len") and self.model.seq_len is None:
-             if self.data.seq_len is not None:
+         if hasattr(self.model, "seq_len"):
+             if self.model.seq_len is None and self.data.seq_len is not None:
                  self.model.seq_len = self.data.seq_len
-             else:
-                 raise ValueError("Either model.seq_len or data.seq_len must be set.")
+             if self.model.seq_len is None:
+                 raise ValueError("seq_len must be specified in either model or data configuration.")
+
+         if isinstance(self.distributed, FSDP):
+             self.distributed.dtype = self.distributed.dtype or self.train.dtype
+             self.distributed.amp_dtype = self.distributed.amp_dtype or self.train.amp_dtype

      @classmethod
      def create(
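The new `Train.amp_dtype` field and the `__post_init__` change above let FSDP inherit dtypes from the training config when its own are left unset. A simplified standalone sketch of that fallback, using plain dataclasses rather than sarasa's generic Config:

```python
# Simplified standalone sketch of the dtype fallback added above; these are
# plain dataclasses, not sarasa's actual Config/Train/FSDP classes.
import dataclasses


@dataclasses.dataclass
class Train:
    dtype: str = "float32"        # model initialization dtype
    amp_dtype: str = "bfloat16"   # autocast dtype


@dataclasses.dataclass
class FSDP:
    dtype: str | None = None      # reduce dtype; falls back to Train.dtype
    amp_dtype: str | None = None  # param dtype; falls back to Train.amp_dtype


train, fsdp = Train(), FSDP()
# mirrors config.__post_init__: unset FSDP dtypes inherit the training dtypes
fsdp.dtype = fsdp.dtype or train.dtype
fsdp.amp_dtype = fsdp.amp_dtype or train.amp_dtype
assert (fsdp.dtype, fsdp.amp_dtype) == ("float32", "bfloat16")
```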
@@ -227,6 +241,8 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:

          import tyro

+         from sarasa.utils import rank
+
          loaded_config = None

          if (under := ("--config_file" in sys.argv)) or ("--config-file" in sys.argv):
@@ -262,6 +278,7 @@ class Config[ModelT, OptimizerT, LRSchedulerT, DataT]:
                  data_type,
              ],
              default=loaded_config,
+             console_outputs=(rank() == 0),
          )

@@ -21,6 +21,8 @@ class ModelConfig:
      vocab_size: int | None = None  # set later based on tokenizer
      seq_len: int | None = None  # set later based on data config
      qk_norm: bool = False  # whether to use RMSNorm on q/k
+     rms_eps: float | None = None  # epsilon for RMSNorm, default to library default if None
+     rms_learnable: bool = False  # whether RMSNorm has learnable scale parameter

      def __post_init__(self):
          # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
@@ -3,7 +3,7 @@ from torch import nn
  from torch.nn import functional as F

  from sarasa.models import ModelConfig
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE


  class SDPAttention(nn.Module):
@@ -57,7 +57,11 @@ class CausalSelfAttention(nn.Module):
          self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
          self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
          self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
-         self.qk_norm = RMSNorm(self.head_dim) if config.qk_norm else nn.Identity()
+         self.qk_norm = (
+             nn.RMSNorm(self.head_dim, eps=config.rms_eps, elementwise_affine=config.rms_learnable)
+             if config.qk_norm
+             else nn.Identity()
+         )

          # todo: support varlen etc and kv caching
          self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)
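The attention change replaces the project's affine-free RMSNorm with `torch.nn.RMSNorm`, wired to `rms_eps` and `rms_learnable`. A small illustration of that module on a per-head tensor (shapes are arbitrary, not sarasa defaults):

```python
# Small illustration of the nn.RMSNorm configuration used above; the tensor
# shapes are arbitrary. eps maps to config.rms_eps and elementwise_affine to
# config.rms_learnable in the diff.
import torch
from torch import nn

head_dim = 64
qk_norm = nn.RMSNorm(head_dim, eps=1e-5, elementwise_affine=True)

q = torch.randn(2, 8, 128, head_dim)  # (batch, heads, seq, head_dim)
q_normed = qk_norm(q)                 # normalizes over the trailing head_dim
print(q_normed.shape)                 # torch.Size([2, 8, 128, 64])
```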
@@ -4,7 +4,7 @@ from torch.nn import functional as F

  from sarasa.models import BaseModel, ModelConfig
  from sarasa.models.attention import CausalSelfAttention
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE


  class MLP(nn.Module):
@@ -41,15 +41,16 @@ class Block(nn.Module):
          self.layer_idx = layer_idx
          self.attention = CausalSelfAttention(config)
          self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
-         self.norm = RMSNorm(config.hidden_dim)
+         self.attn_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
+         self.mlp_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)

      def forward(
          self,
          x: torch.Tensor,
          cos_sin: tuple[torch.Tensor, torch.Tensor],
      ) -> torch.Tensor:
-         x = x + self.attention(self.norm(x), cos_sin)
-         x = x + self.mlp(self.norm(x))
+         x = x + self.attention(self.attn_norm(x), cos_sin)
+         x = x + self.mlp(self.mlp_norm(x))
          return x


@@ -71,7 +72,7 @@ class Llama3(BaseModel):
          self.blocks = nn.ModuleList([
              Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
          ])
-         self.norm = RMSNorm(config.hidden_dim)
+         self.norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
          self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

      @torch.no_grad()
@@ -101,16 +102,26 @@ class Llama3(BaseModel):
              b=cutoff_factor * final_out_std,
          )

+         for mod in self.modules():
+             if isinstance(mod, nn.RMSNorm):
+                 mod.reset_parameters()
+
      def param_groups(self) -> dict[str, list[nn.Parameter]]:
-         matrix_params = list(self.blocks.parameters())
+         matrix_params = [param for param in self.blocks.parameters() if param.ndim == 2]
          embedding_params = list(self.token_emb.parameters())
          lm_head_params = list(self.output.parameters())
-         assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) + len(lm_head_params))
+         rms_norm_params = [
+             mod.weight for mod in self.modules() if isinstance(mod, nn.RMSNorm) and mod.elementwise_affine
+         ]
+         assert len(list(self.parameters())) == (
+             len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(rms_norm_params)
+         )

          return {
              "matrix": matrix_params,
              "embedding": embedding_params,
              "lm_head": lm_head_params,
+             "rms_norm": rms_norm_params,
          }

      def forward(
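`param_groups()` now keeps only rank-2 tensors in the matrix group and gives learnable RMSNorm scales their own group. A toy illustration of the same `ndim`-based split on a made-up module:

```python
# Toy illustration of splitting parameters by tensor rank, as param_groups()
# above now does; the module here is made up, only the ndim test matters.
import torch
from torch import nn

block = nn.Sequential(
    nn.Linear(16, 16, bias=False),            # weight: ndim == 2
    nn.RMSNorm(16, elementwise_affine=True),  # weight: ndim == 1
)

matrix_params = [p for p in block.parameters() if p.ndim == 2]
rms_norm_params = [
    m.weight for m in block.modules() if isinstance(m, nn.RMSNorm) and m.elementwise_affine
]
assert len(matrix_params) == 1 and len(rms_norm_params) == 1
```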
@@ -8,7 +8,16 @@ from torch.nn import functional as F

  from sarasa.models import BaseModel, ModelConfig
  from sarasa.models.attention import CausalSelfAttention
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE
+
+
+ class RMSNorm(torch.nn.RMSNorm):
+     # RMSNorm without affine parameters
+     def __init__(
+         self,
+         normalized_shape: int,
+     ):
+         super().__init__(normalized_shape, eps=None, elementwise_affine=False)


  class MLP(nn.Module):
@@ -1,15 +1,6 @@
  import torch


- class RMSNorm(torch.nn.RMSNorm):
-     # RMSNorm without affine parameters
-     def __init__(
-         self,
-         normalized_shape: int,
-     ):
-         super().__init__(normalized_shape, eps=None, elementwise_affine=False)
-
-
  class RoPE:
      @staticmethod
      def precompute(
@@ -7,6 +7,7 @@ import torch
  import torch.distributed as dist
  from loguru import logger
  from torch.distributed.elastic.multiprocessing.errors import record
+ from torch.nn import functional as F

  from sarasa.activation_checkpoint import apply_op_sac
  from sarasa.checkpoint import Checkpointer
@@ -48,9 +49,6 @@ class Trainer:
          vocab_size = len(self.tokenizer)
          self.config.model.vocab_size = vocab_size

-         # todo: support other loss functions
-         self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")
-
          # setup model, optimizer, lr scheduler
          with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
              self.model = self.config.model.create()
@@ -68,9 +66,9 @@ class Trainer:
          if config.train.compile:
              logger.info("Compiling the model")
              for block in self.model.blocks:
-                 block.compile(fullgraph=True)
+                 block.compile(fullgraph=True, dynamic=False)
              self.model.compile(dynamic=False)
-             self.loss_fn.compile()
+             self.loss_fn = torch.compile(self.loss_fn, fullgraph=True, dynamic=False)

          if world_size() > 1:
              apply_distributed(
@@ -102,7 +100,10 @@ class Trainer:

          self.amp_context = contextlib.nullcontext()
          if config.distributed.name != "fsdp":
-             self.amp_context = torch.autocast(device_type=self.device.type, dtype=getattr(torch, config.train.dtype))
+             self.amp_context = torch.autocast(
+                 device_type=self.device.type,
+                 dtype=getattr(torch, config.train.amp_dtype),
+             )

          # todo: setup profiler context
          self.profile_context = contextlib.nullcontext()
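The non-FSDP path now derives its autocast dtype from `train.amp_dtype` instead of the initialization dtype. A minimal sketch of that setup; `"cpu"` and `"bfloat16"` are illustrative values only:

```python
# Minimal sketch of the autocast setup above; "cpu" and "bfloat16" are just
# illustrative values for device_type and config.train.amp_dtype.
import torch

amp_dtype = "bfloat16"  # stands in for config.train.amp_dtype
amp_context = torch.autocast(device_type="cpu", dtype=getattr(torch, amp_dtype))

with amp_context:
    x = torch.randn(4, 8)
    w = torch.randn(8, 8)
    y = x @ w  # matmul runs in bfloat16 under autocast
print(y.dtype)  # torch.bfloat16
```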
@@ -116,40 +117,36 @@ class Trainer:
                  f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
              )

-     def __del__(self) -> None:
-         # cleanup distributed
-         if world_size() > 1:
-             try:
-                 dist.destroy_process_group()
-             except Exception as e:
-                 logger.warning(f"Failed to destroy process group: {e}")
-
      @record
      def train(self):
-         logger.info("Starting training...")
-
-         self.model.train()
-         with self.profile_context:
-             data_iter = self.batch_generator(self.data_loader)
-             for _ in range(self.config.train.steps):
-                 self.step += 1
-                 self.gc.collect(self.step)
-                 try:
-                     self.train_step(data_iter)
-                 except StopIteration:
-                     logger.warning("Data loader exhausted during training.")
-                     break
-
-                 if self.checkpointer is not None:
-                     self.checkpointer.save(self.step)
-
-                 if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
-                     self.evaluate()
-
-                 if world_size() > 1 and self.step == 1:
-                     update_timeout(self.config.distributed.train_timeout_seconds, self.device)
-
-         logger.info("Training completed.")
+         try:
+             logger.info("Starting training...")
+
+             self.model.train()
+             with self.profile_context:
+                 data_iter = self.batch_generator(self.data_loader)
+                 for _ in range(self.config.train.steps):
+                     self.step += 1
+                     self.gc.collect(self.step)
+                     try:
+                         self.train_step(data_iter)
+                     except StopIteration:
+                         logger.warning("Data loader exhausted during training.")
+                         break
+
+                     if self.checkpointer is not None:
+                         self.checkpointer.save(self.step)
+
+                     if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
+                         self.evaluate()
+
+                     if world_size() > 1 and self.step == 1:
+                         update_timeout(self.config.distributed.train_timeout_seconds, self.device)
+
+             logger.info("Training completed.")
+         finally:
+             logger.info("Cleaning up trainer...")
+             self.close()

      def batch_generator(
          self,
@@ -190,7 +187,7 @@ class Trainer:

          with self.amp_context:
              pred = self.model(**input_dict)
-             loss = self.loss_fn(pred.flatten(0, 1), target.flatten(0, 1)) / valid_tokens
+             loss = self.loss_fn(pred, target) / valid_tokens

          del pred
          loss.backward()
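`train_step` now divides a sum-reduced loss by the number of valid tokens, with the cross-entropy itself moved into a `loss_fn` method shown further down in this diff. A self-contained sketch of that normalization; the `IGNORE_INDEX = -100` value is an assumption, as the diff does not show sarasa's constant:

```python
# Self-contained sketch of sum-reduced cross entropy normalized by valid
# tokens, as in the loss_fn/train_step changes in this diff. IGNORE_INDEX = -100
# is an assumption (it matches F.cross_entropy's default ignore_index).
import torch
from torch.nn import functional as F

IGNORE_INDEX = -100
batch, seq, vocab = 2, 5, 11

pred = torch.randn(batch, seq, vocab)
target = torch.randint(0, vocab, (batch, seq))
target[0, :2] = IGNORE_INDEX  # mask a few positions, e.g. padding

loss_sum = F.cross_entropy(
    pred.flatten(0, 1).float(),
    target.flatten(0, 1),
    ignore_index=IGNORE_INDEX,
    reduction="sum",
)
valid_tokens = (target != IGNORE_INDEX).sum()
loss = loss_sum / valid_tokens  # mean over non-ignored tokens only
```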
@@ -234,6 +231,18 @@ class Trainer:
              },
          )

+     def loss_fn(
+         self,
+         pred: torch.Tensor,
+         target: torch.Tensor,
+     ) -> torch.Tensor:
+         return F.cross_entropy(
+             pred.flatten(0, 1).float(),
+             target.flatten(0, 1),
+             ignore_index=IGNORE_INDEX,
+             reduction="sum",
+         )
+
      def evaluate(self):
          raise NotImplementedError

@@ -242,3 +251,20 @@ class Trainer:
          batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
      ) -> None:
          raise NotImplementedError
+
+     def close(self) -> None:
+         if self.checkpointer is not None:
+             self.checkpointer.close()
+
+         if self.metrics_processor is not None:
+             self.metrics_processor.close()
+
+         # cleanup distributed
+         if world_size() > 1:
+             try:
+                 dist.destroy_process_group()
+             except Exception as e:
+                 logger.warning(f"Failed to destroy process group: {e}")
+
+     def __del__(self):
+         self.close()
@@ -3,6 +3,7 @@ import gc
  import os
  import sys
  import time
+ import typing
  from datetime import timedelta
  from functools import cache

@@ -11,8 +12,11 @@ from loguru import logger
  from torch import distributed as dist
  from torch import nn

+ if typing.TYPE_CHECKING:
+     from sarasa.config import Config, Distributed

- def setup_logger(config) -> None:
+
+ def setup_logger(config: Config) -> None:
      logger.remove()
      if config.debug:
          logger_format = f"<blue>RANK={rank()}</blue> | " + (
@@ -128,7 +132,7 @@ def update_timeout(


  def apply_distributed(
-     config,
+     config: Distributed,
      model: nn.Module,
      device: torch.device,
      compile: bool,
@@ -149,8 +153,8 @@ def apply_distributed(

          # todo: make dtypes configurable
          mp_policy = MixedPrecisionPolicy(
-             param_dtype=torch.bfloat16,
-             reduce_dtype=torch.float32,
+             param_dtype=getattr(torch, config.amp_dtype),
+             reduce_dtype=getattr(torch, config.dtype),
          )

          for block in model.blocks:
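`apply_distributed` now resolves the FSDP mixed-precision dtypes from config strings rather than hard-coding bfloat16/float32. A short sketch of the string-to-dtype resolution; `cfg` is a stand-in namespace, not sarasa's Distributed config:

```python
# Short sketch of resolving dtype strings into torch dtypes, as the
# MixedPrecisionPolicy change above does; `cfg` is a stand-in object.
import types

import torch

cfg = types.SimpleNamespace(amp_dtype="bfloat16", dtype="float32")

param_dtype = getattr(torch, cfg.amp_dtype)  # torch.bfloat16 (parameter/compute dtype)
reduce_dtype = getattr(torch, cfg.dtype)     # torch.float32 (gradient reduce dtype)
assert param_dtype is torch.bfloat16 and reduce_dtype is torch.float32
```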
@@ -73,3 +73,10 @@ def test_config_loading_filetype_error(monkeypatch, tmp_path):
      monkeypatch.setattr(sys, "argv", ["program", "--config_file", str(config_file)])
      with pytest.raises(ValueError):
          Config.from_cli()
+
+
+ def test_config_post_init(monkeypatch):
+     monkeypatch.setattr(sys, "argv", ["program", "distributed:fsdp"])
+     config = Config.from_cli()  # just check no error is raised
+     assert config.distributed.dtype == config.train.dtype
+     assert config.distributed.amp_dtype == config.train.amp_dtype
@@ -1,2 +0,0 @@
- from .config import Config as Config
- from .trainer import Trainer as Trainer