sarasa 0.0.3__tar.gz → 0.0.4__tar.gz

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (31)
  1. {sarasa-0.0.3 → sarasa-0.0.4}/PKG-INFO +10 -4
  2. {sarasa-0.0.3 → sarasa-0.0.4}/README.md +9 -3
  3. {sarasa-0.0.3 → sarasa-0.0.4}/configs/example.py +5 -2
  4. {sarasa-0.0.3 → sarasa-0.0.4}/configs/llama3-1b.py +4 -3
  5. sarasa-0.0.4/sarasa/__init__.py +11 -0
  6. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/checkpoint.py +10 -0
  7. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/__init__.py +2 -0
  8. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/attention.py +6 -2
  9. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/llama3.py +18 -7
  10. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/nanochat_gpt.py +10 -1
  11. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/utils.py +0 -9
  12. sarasa-0.0.3/sarasa/trainer.py → sarasa-0.0.4/sarasa/train.py +26 -14
  13. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/utils.py +6 -2
  14. sarasa-0.0.3/sarasa/__init__.py +0 -2
  15. {sarasa-0.0.3 → sarasa-0.0.4}/.github/workflows/pypi.yaml +0 -0
  16. {sarasa-0.0.3 → sarasa-0.0.4}/.github/workflows/tests_and_lint.yaml +0 -0
  17. {sarasa-0.0.3 → sarasa-0.0.4}/.gitignore +0 -0
  18. {sarasa-0.0.3 → sarasa-0.0.4}/LICENSE +0 -0
  19. {sarasa-0.0.3 → sarasa-0.0.4}/main.py +0 -0
  20. {sarasa-0.0.3 → sarasa-0.0.4}/pyproject.toml +0 -0
  21. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/activation_checkpoint.py +0 -0
  22. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/config.py +0 -0
  23. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/data/__init__.py +0 -0
  24. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/data/hf_datasets.py +0 -0
  25. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/data/tokenizer.py +0 -0
  26. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/metrics.py +0 -0
  27. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/optimizers/__init__.py +0 -0
  28. {sarasa-0.0.3 → sarasa-0.0.4}/sarasa/optimizers/utils.py +0 -0
  29. {sarasa-0.0.3 → sarasa-0.0.4}/tests/test_config.py +0 -0
  30. {sarasa-0.0.3 → sarasa-0.0.4}/tests/test_model.py +0 -0
  31. {sarasa-0.0.3 → sarasa-0.0.4}/tests/test_utils.py +0 -0
{sarasa-0.0.3 → sarasa-0.0.4}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sarasa
- Version: 0.0.3
+ Version: 0.0.4
  Summary: Add your description here
  License-File: LICENSE
  Requires-Python: >=3.13
@@ -46,6 +46,8 @@ uv add sarasa[cpu|cu128|cu130]
  - Async distributed checkpoint saving

  - [ ] Checkpoint loading
+ - [ ] FP8 training
+ - [ ] Profiling

  ## Usage

@@ -100,18 +102,22 @@ if __name__ == "__main__":
      trainer.train()
  ```

+ Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
  From the command line, you can specify which custom optimizer to use:

  ```bash
  python script.py optim:custom_optim --optim.lr 0.001 ...
  ```

+ (As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
  ### Config File Example

- It's very simple. IDE autocompletion will help you.
+ It's very simple.
+ IDE autocompletion will help you.

  ```python
- from sarasa.config import Config, Data, LRScheduler, Model, Train, LRScheduler
+ from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
  from custom_optim import CustomOptim

  # only one Config instance should be defined in each config file
@@ -135,4 +141,4 @@ config = Config.create(

  ## Acknowledgements

- This project is heavily inspired by and borrows code from `torchtitan`.
+ This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
{sarasa-0.0.3 → sarasa-0.0.4}/README.md
@@ -23,6 +23,8 @@ uv add sarasa[cpu|cu128|cu130]
  - Async distributed checkpoint saving

  - [ ] Checkpoint loading
+ - [ ] FP8 training
+ - [ ] Profiling

  ## Usage

@@ -77,18 +79,22 @@ if __name__ == "__main__":
      trainer.train()
  ```

+ Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
  From the command line, you can specify which custom optimizer to use:

  ```bash
  python script.py optim:custom_optim --optim.lr 0.001 ...
  ```

+ (As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
  ### Config File Example

- It's very simple. IDE autocompletion will help you.
+ It's very simple.
+ IDE autocompletion will help you.

  ```python
- from sarasa.config import Config, Data, LRScheduler, Model, Train, LRScheduler
+ from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
  from custom_optim import CustomOptim

  # only one Config instance should be defined in each config file
@@ -112,4 +118,4 @@ config = Config.create(

  ## Acknowledgements

- This project is heavily inspired by and borrows code from `torchtitan`.
+ This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
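The tyro-based optimizer selection described in the README above boils down to a union of config dataclasses that tyro exposes as subcommands. Below is a minimal, hypothetical sketch of that underlying pattern, not Sarasa's actual `Config.create` machinery; the `AdamW`, `CustomOptim`, and `Args` classes here are illustrative, and the exact subcommand spelling depends on tyro's name-conversion settings.

```python
# Hypothetical sketch of the tyro union-as-subcommand pattern the README relies on.
import dataclasses

import tyro


@dataclasses.dataclass
class AdamW:
    lr: float = 3e-4
    weight_decay: float = 0.1


@dataclasses.dataclass
class CustomOptim:  # stand-in for a user-defined optimizer config
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class Args:
    # A union over config dataclasses becomes one subcommand per member type,
    # selected on the command line as `optim:<name derived from the class name>`.
    optim: AdamW | CustomOptim = dataclasses.field(default_factory=AdamW)


if __name__ == "__main__":
    args = tyro.cli(Args)  # e.g. `python script.py optim:custom-optim --optim.lr 0.001`
    print(args.optim)
```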
{sarasa-0.0.3 → sarasa-0.0.4}/configs/example.py
@@ -1,11 +1,14 @@
  from sarasa.config import AdamW, Config, Data, LRScheduler, Model, Train

  config = Config.create(
-     model=Model(num_layers=12),
+     model=Model(
+         name="nanochat_gpt",
+         num_layers=12,
+         qk_norm=True,
+     ),
      train=Train(
          local_batch_size=16,
          global_batch_size=256,
-         dtype="bfloat16",
      ),
      data=Data(tokenizer_path="./tokenizer"),
      lr_scheduler=LRScheduler(
{sarasa-0.0.3 → sarasa-0.0.4}/configs/llama3-1b.py
@@ -2,17 +2,18 @@ from sarasa.config import FSDP, AdamW, Config, Data, LRScheduler, Model, Train

  config = Config.create(
      model=Model(
+         name="llama3",
          hidden_dim=2048,
          num_layers=16,
          num_heads=32,
          num_kv_heads=8,
          head_dim=64,
-         name="llama3",
+         rms_eps=1e-5,
+         rms_learnable=True,
      ),
      train=Train(
          local_batch_size=32,
-         global_batch_size=256,
-         dtype="bfloat16",
+         global_batch_size=1024,
          use_sac=True,
      ),
      data=Data(tokenizer_path="./tokenizer"),
sarasa-0.0.4/sarasa/__init__.py
@@ -0,0 +1,11 @@
+ from .config import DDP as DDP
+ from .config import FSDP as FSDP
+ from .config import AdamW as AdamW
+ from .config import Checkpoint as Checkpoint
+ from .config import Config as Config
+ from .config import Data as Data
+ from .config import LRScheduler as LRScheduler
+ from .config import Metrics as Metrics
+ from .config import Model as Model
+ from .config import Train as Train
+ from .train import Trainer as Trainer
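The `from .config import X as X` spelling above is the explicit re-export idiom: strict type checkers treat the redundant alias as an intentional public re-export rather than an unused import. It is what enables the flatter imports shown in the updated README, e.g.:

```python
# With the new package-level __init__.py, configs can import everything from the root:
from sarasa import Config, Data, LRScheduler, Model, Train, Trainer
```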
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/checkpoint.py
@@ -110,3 +110,13 @@ class Checkpointer:
      def close(self) -> None:
          if self.stager is not None:
              self.stager.close()
+
+         if self.save_future is not None:
+             self.save_future.result()
+
+         if self.pg is not None:
+             dist.destroy_process_group(self.pg)
+             self.pg = None
+
+     def __del__(self):
+         self.close()
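The `close()`/`__del__` pair added above follows a common teardown pattern: drain the pending async save, release the checkpoint process group, null out the handle so a second call is a no-op, and let `__del__` fall back to `close()` if the caller never called it. A generic sketch of that pattern (hypothetical class, not Sarasa's `Checkpointer`):

```python
from concurrent.futures import Future


class AsyncResource:
    """Generic sketch: idempotent close() plus a __del__ safety net."""

    def __init__(self) -> None:
        self.pending: Future | None = None  # e.g. an in-flight async checkpoint save
        self.handle: object | None = None   # e.g. a checkpoint process group

    def close(self) -> None:
        if self.pending is not None:
            self.pending.result()  # wait for outstanding work before teardown
            self.pending = None
        if self.handle is not None:
            # release the handle (dist.destroy_process_group(self.pg) in the real code)
            self.handle = None

    def __del__(self) -> None:
        # Best-effort cleanup if close() was never called explicitly.
        self.close()
```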
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/__init__.py
@@ -21,6 +21,8 @@ class ModelConfig:
      vocab_size: int | None = None  # set later based on tokenizer
      seq_len: int | None = None  # set later based on data config
      qk_norm: bool = False  # whether to use RMSNorm on q/k
+     rms_eps: float | None = None  # epsilon for RMSNorm, default to library default if None
+     rms_learnable: bool = False  # whether RMSNorm has learnable scale parameter

      def __post_init__(self):
          # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/attention.py
@@ -3,7 +3,7 @@ from torch import nn
  from torch.nn import functional as F

  from sarasa.models import ModelConfig
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE


  class SDPAttention(nn.Module):
@@ -57,7 +57,11 @@ class CausalSelfAttention(nn.Module):
          self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
          self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
          self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
-         self.qk_norm = RMSNorm(self.head_dim) if config.qk_norm else nn.Identity()
+         self.qk_norm = (
+             nn.RMSNorm(self.head_dim, eps=config.rms_eps, elementwise_affine=config.rms_learnable)
+             if config.qk_norm
+             else nn.Identity()
+         )

          # todo: support varlen etc and kv caching
          self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)
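For reference, the two new config knobs map directly onto `torch.nn.RMSNorm` arguments (available in PyTorch 2.4+): `rms_eps` feeds `eps` (with `None` meaning the library default), and `rms_learnable` feeds `elementwise_affine`. A small standalone sketch of the difference:

```python
import torch
from torch import nn

head_dim = 64

# rms_learnable=False -> no learnable scale, so the norm contributes no parameters
norm_fixed = nn.RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)
# rms_learnable=True -> a per-channel weight initialized to ones
norm_learnable = nn.RMSNorm(head_dim, eps=1e-5, elementwise_affine=True)

x = torch.randn(2, 8, head_dim)
print(norm_fixed(x).shape)                                   # torch.Size([2, 8, 64])
print(sum(p.numel() for p in norm_fixed.parameters()))       # 0
print(sum(p.numel() for p in norm_learnable.parameters()))   # 64
```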
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/llama3.py
@@ -4,7 +4,7 @@ from torch.nn import functional as F

  from sarasa.models import BaseModel, ModelConfig
  from sarasa.models.attention import CausalSelfAttention
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE


  class MLP(nn.Module):
@@ -41,15 +41,16 @@ class Block(nn.Module):
          self.layer_idx = layer_idx
          self.attention = CausalSelfAttention(config)
          self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
-         self.norm = RMSNorm(config.hidden_dim)
+         self.attn_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
+         self.mlp_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)

      def forward(
          self,
          x: torch.Tensor,
          cos_sin: tuple[torch.Tensor, torch.Tensor],
      ) -> torch.Tensor:
-         x = x + self.attention(self.norm(x), cos_sin)
-         x = x + self.mlp(self.norm(x))
+         x = x + self.attention(self.attn_norm(x), cos_sin)
+         x = x + self.mlp(self.mlp_norm(x))
          return x


@@ -71,7 +72,7 @@ class Llama3(BaseModel):
          self.blocks = nn.ModuleList([
              Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
          ])
-         self.norm = RMSNorm(config.hidden_dim)
+         self.norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
          self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

      @torch.no_grad()
@@ -101,16 +102,26 @@ class Llama3(BaseModel):
              b=cutoff_factor * final_out_std,
          )

+         for mod in self.modules():
+             if isinstance(mod, nn.RMSNorm):
+                 mod.reset_parameters()
+
      def param_groups(self) -> dict[str, list[nn.Parameter]]:
-         matrix_params = list(self.blocks.parameters())
+         matrix_params = [param for param in self.blocks.parameters() if param.ndim == 2]
          embedding_params = list(self.token_emb.parameters())
          lm_head_params = list(self.output.parameters())
-         assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) + len(lm_head_params))
+         rms_norm_params = [
+             mod.weight for mod in self.modules() if isinstance(mod, nn.RMSNorm) and mod.elementwise_affine
+         ]
+         assert len(list(self.parameters())) == (
+             len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(rms_norm_params)
+         )

          return {
              "matrix": matrix_params,
              "embedding": embedding_params,
              "lm_head": lm_head_params,
+             "rms_norm": rms_norm_params,
          }

      def forward(
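`param_groups()` now returns four named groups, with only 2-D block parameters in `"matrix"` and the learnable RMSNorm scales split into their own `"rms_norm"` group. One way such a dict could be consumed is per-group optimizer hyperparameters; the sketch below is hypothetical and not Sarasa's optimizer wiring (`build_optimizer` and the `lrs` mapping are illustrative):

```python
import torch


def build_optimizer(model, lrs: dict[str, float], weight_decay: float = 0.1):
    # Expects lrs to have one entry per group name: "matrix", "embedding", "lm_head", "rms_norm".
    groups = model.param_groups()
    return torch.optim.AdamW(
        [
            {
                "params": params,
                "lr": lrs[name],
                # norm scales usually train without weight decay
                "weight_decay": 0.0 if name == "rms_norm" else weight_decay,
            }
            for name, params in groups.items()
        ]
    )
```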
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/nanochat_gpt.py
@@ -8,7 +8,16 @@ from torch.nn import functional as F

  from sarasa.models import BaseModel, ModelConfig
  from sarasa.models.attention import CausalSelfAttention
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE
+
+
+ class RMSNorm(torch.nn.RMSNorm):
+     # RMSNorm without affine parameters
+     def __init__(
+         self,
+         normalized_shape: int,
+     ):
+         super().__init__(normalized_shape, eps=None, elementwise_affine=False)


  class MLP(nn.Module):
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/models/utils.py
@@ -1,15 +1,6 @@
  import torch


- class RMSNorm(torch.nn.RMSNorm):
-     # RMSNorm without affine parameters
-     def __init__(
-         self,
-         normalized_shape: int,
-     ):
-         super().__init__(normalized_shape, eps=None, elementwise_affine=False)
-
-
  class RoPE:
      @staticmethod
      def precompute(
sarasa-0.0.3/sarasa/trainer.py → sarasa-0.0.4/sarasa/train.py
@@ -7,6 +7,7 @@ import torch
  import torch.distributed as dist
  from loguru import logger
  from torch.distributed.elastic.multiprocessing.errors import record
+ from torch.nn import functional as F

  from sarasa.activation_checkpoint import apply_op_sac
  from sarasa.checkpoint import Checkpointer
@@ -48,9 +49,6 @@ class Trainer:
          vocab_size = len(self.tokenizer)
          self.config.model.vocab_size = vocab_size

-         # todo: support other loss functions
-         self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")
-
          # setup model, optimizer, lr scheduler
          with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
              self.model = self.config.model.create()
@@ -68,9 +66,9 @@
          if config.train.compile:
              logger.info("Compiling the model")
              for block in self.model.blocks:
-                 block.compile(fullgraph=True)
+                 block.compile(fullgraph=True, dynamic=False)
              self.model.compile(dynamic=False)
-             self.loss_fn.compile()
+             self.loss_fn = torch.compile(self.loss_fn, fullgraph=True, dynamic=False)

          if world_size() > 1:
              apply_distributed(
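The compile hunk above mixes the two `torch.compile` idioms: `nn.Module.compile(...)` compiles a module in place (kwargs are forwarded to `torch.compile`), while `torch.compile(fn)` returns a new compiled callable, which is why `self.loss_fn` is reassigned. A standalone sketch with a toy block and loss (not the Trainer's actual modules):

```python
import torch
from torch import nn
from torch.nn import functional as F

# Per-block ("regional") compilation: Module.compile mutates the module in place.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU())
block.compile(fullgraph=True, dynamic=False)


def loss_fn(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(pred, target)


# torch.compile on a function returns a wrapper, so rebind the name to use it.
loss_fn = torch.compile(loss_fn, fullgraph=True, dynamic=False)
```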
@@ -119,14 +117,6 @@
                  f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
              )

-     def __del__(self) -> None:
-         # cleanup distributed
-         if world_size() > 1:
-             try:
-                 dist.destroy_process_group()
-             except Exception as e:
-                 logger.warning(f"Failed to destroy process group: {e}")
-
      @record
      def train(self):
          try:
@@ -197,7 +187,7 @@

          with self.amp_context:
              pred = self.model(**input_dict)
-             loss = self.loss_fn(pred.flatten(0, 1), target.flatten(0, 1)) / valid_tokens
+             loss = self.loss_fn(pred, target) / valid_tokens

          del pred
          loss.backward()
@@ -241,6 +231,18 @@
              },
          )

+     def loss_fn(
+         self,
+         pred: torch.Tensor,
+         target: torch.Tensor,
+     ) -> torch.Tensor:
+         return F.cross_entropy(
+             pred.flatten(0, 1).float(),
+             target.flatten(0, 1),
+             ignore_index=IGNORE_INDEX,
+             reduction="sum",
+         )
+
      def evaluate(self):
          raise NotImplementedError

@@ -256,3 +258,13 @@

          if self.metrics_processor is not None:
              self.metrics_processor.close()
+
+         # cleanup distributed
+         if world_size() > 1:
+             try:
+                 dist.destroy_process_group()
+             except Exception as e:
+                 logger.warning(f"Failed to destroy process group: {e}")
+
+     def __del__(self):
+         self.close()
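With the hunks above, the loss moves from a `CrossEntropyLoss` module to a plain `F.cross_entropy` call that sums per-token losses in float32; the train step then divides by `valid_tokens` to get a mean over non-ignored tokens. A self-contained sketch of that computation, assuming the conventional `IGNORE_INDEX` value of -100 (the actual constant lives in Sarasa's source and is not shown in this diff):

```python
import torch
from torch.nn import functional as F

IGNORE_INDEX = -100  # assumption: the usual PyTorch ignore value

pred = torch.randn(2, 5, 32)            # (batch, seq, vocab) logits
target = torch.randint(0, 32, (2, 5))   # (batch, seq) token ids
target[0, :2] = IGNORE_INDEX            # mask a few positions

valid_tokens = (target != IGNORE_INDEX).sum()
loss_sum = F.cross_entropy(
    pred.flatten(0, 1).float(),   # upcast logits for a stable reduction
    target.flatten(0, 1),
    ignore_index=IGNORE_INDEX,
    reduction="sum",
)
loss = loss_sum / valid_tokens    # mean over non-ignored tokens
print(loss)
```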
{sarasa-0.0.3 → sarasa-0.0.4}/sarasa/utils.py
@@ -3,6 +3,7 @@ import gc
  import os
  import sys
  import time
+ import typing
  from datetime import timedelta
  from functools import cache

@@ -11,8 +12,11 @@ from loguru import logger
  from torch import distributed as dist
  from torch import nn

+ if typing.TYPE_CHECKING:
+     from sarasa.config import Config, Distributed

- def setup_logger(config) -> None:
+
+ def setup_logger(config: Config) -> None:
      logger.remove()
      if config.debug:
          logger_format = f"<blue>RANK={rank()}</blue> | " + (
@@ -128,7 +132,7 @@ def update_timeout(


  def apply_distributed(
-     config,
+     config: Distributed,
      model: nn.Module,
      device: torch.device,
      compile: bool,
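The `typing.TYPE_CHECKING` guard used above makes `Config` and `Distributed` available to type checkers without importing them at runtime, which avoids import cycles. A minimal sketch of the pattern; such guards are typically paired with `from __future__ import annotations` (or quoted annotations) so the names are never evaluated when the function is defined:

```python
from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    # Only imported for static analysis; never executed at runtime.
    from sarasa.config import Config


def setup_logger(config: Config) -> None:
    ...
```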
sarasa-0.0.3/sarasa/__init__.py
@@ -1,2 +0,0 @@
- from .config import Config as Config
- from .trainer import Trainer as Trainer