sarasa 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sarasa/__init__.py CHANGED
@@ -1,2 +1,11 @@
+ from .config import DDP as DDP
+ from .config import FSDP as FSDP
+ from .config import AdamW as AdamW
+ from .config import Checkpoint as Checkpoint
  from .config import Config as Config
- from .trainer import Trainer as Trainer
+ from .config import Data as Data
+ from .config import LRScheduler as LRScheduler
+ from .config import Metrics as Metrics
+ from .config import Model as Model
+ from .config import Train as Train
+ from .train import Trainer as Trainer
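The practical effect of the widened re-exports, as a minimal sketch (assuming sarasa 0.0.4 is installed): config building blocks can now be imported from the package root instead of from `sarasa.config`, matching the updated README example further down in this diff.

```python
# Hypothetical config file using the new top-level re-exports.
from sarasa import Config, Data, LRScheduler, Model, Train, Trainer  # noqa: F401
```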
sarasa/checkpoint.py CHANGED
@@ -110,3 +110,13 @@ class Checkpointer:
  def close(self) -> None:
      if self.stager is not None:
          self.stager.close()
+
+     if self.save_future is not None:
+         self.save_future.result()
+
+     if self.pg is not None:
+         dist.destroy_process_group(self.pg)
+         self.pg = None
+
+ def __del__(self):
+     self.close()
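A standalone sketch of the cleanup pattern these added lines follow, using a thread pool as a stand-in for the async checkpoint stager and process group (the attribute name `save_future` comes from the hunk; everything else here is illustrative): wait for the in-flight save, release the resource exactly once so `close()` stays idempotent, and let `__del__` delegate to the same path.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncSaver:
    """Toy stand-in for the Checkpointer cleanup pattern."""

    def __init__(self) -> None:
        self.executor: ThreadPoolExecutor | None = ThreadPoolExecutor(max_workers=1)
        self.save_future: Future | None = None

    def save(self) -> None:
        assert self.executor is not None
        self.save_future = self.executor.submit(lambda: "written")

    def close(self) -> None:
        # Block until the pending async save has landed before releasing resources.
        if self.save_future is not None:
            self.save_future.result()
            self.save_future = None

        # Release the resource once and mark it gone so repeated close() calls are safe.
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

    def __del__(self) -> None:
        # Garbage collection falls back to the same cleanup path.
        self.close()


saver = AsyncSaver()
saver.save()
saver.close()
```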
sarasa/models/__init__.py CHANGED
@@ -21,6 +21,8 @@ class ModelConfig:
  vocab_size: int | None = None # set later based on tokenizer
  seq_len: int | None = None # set later based on data config
  qk_norm: bool = False # whether to use RMSNorm on q/k
+ rms_eps: float | None = None # epsilon for RMSNorm, default to library default if None
+ rms_learnable: bool = False # whether RMSNorm has learnable scale parameter

  def __post_init__(self):
      # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
sarasa/models/attention.py CHANGED
@@ -3,7 +3,7 @@ from torch import nn
  from torch.nn import functional as F

  from sarasa.models import ModelConfig
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE


  class SDPAttention(nn.Module):
@@ -57,7 +57,11 @@ class CausalSelfAttention(nn.Module):
  self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
  self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
  self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
- self.qk_norm = RMSNorm(self.head_dim) if config.qk_norm else nn.Identity()
+ self.qk_norm = (
+     nn.RMSNorm(self.head_dim, eps=config.rms_eps, elementwise_affine=config.rms_learnable)
+     if config.qk_norm
+     else nn.Identity()
+ )

  # todo: support varlen etc and kv caching
  self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)
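For reference, a small self-contained sketch of how `torch.nn.RMSNorm` behaves with the two new config knobs: `eps=None` falls back to PyTorch's dtype-dependent epsilon, and `elementwise_affine` controls whether a learnable per-channel scale exists (the values below are illustrative, not sarasa defaults).

```python
import torch
from torch import nn

# elementwise_affine=False gives a parameter-free norm, matching the removed
# custom RMSNorm in sarasa.models.utils; eps=None uses the dtype's default epsilon.
plain = nn.RMSNorm(64, eps=None, elementwise_affine=False)
x = torch.randn(2, 8, 64)
print(plain(x).shape)            # torch.Size([2, 8, 64])
print(list(plain.parameters()))  # [] -- nothing to train

learnable = nn.RMSNorm(64, eps=1e-6, elementwise_affine=True)
print(learnable.weight.shape)    # torch.Size([64]) -- learnable scale, initialized to ones
```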
sarasa/models/llama3.py CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F

  from sarasa.models import BaseModel, ModelConfig
  from sarasa.models.attention import CausalSelfAttention
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE


  class MLP(nn.Module):
@@ -41,15 +41,16 @@ class Block(nn.Module):
      self.layer_idx = layer_idx
      self.attention = CausalSelfAttention(config)
      self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
-     self.norm = RMSNorm(config.hidden_dim)
+     self.attn_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
+     self.mlp_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)

  def forward(
      self,
      x: torch.Tensor,
      cos_sin: tuple[torch.Tensor, torch.Tensor],
  ) -> torch.Tensor:
-     x = x + self.attention(self.norm(x), cos_sin)
-     x = x + self.mlp(self.norm(x))
+     x = x + self.attention(self.attn_norm(x), cos_sin)
+     x = x + self.mlp(self.mlp_norm(x))
      return x

@@ -71,7 +72,7 @@ class Llama3(BaseModel):
      self.blocks = nn.ModuleList([
          Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
      ])
-     self.norm = RMSNorm(config.hidden_dim)
+     self.norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
      self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

  @torch.no_grad()
@@ -101,16 +102,26 @@ class Llama3(BaseModel):
          b=cutoff_factor * final_out_std,
      )

+     for mod in self.modules():
+         if isinstance(mod, nn.RMSNorm):
+             mod.reset_parameters()
+
  def param_groups(self) -> dict[str, list[nn.Parameter]]:
-     matrix_params = list(self.blocks.parameters())
+     matrix_params = [param for param in self.blocks.parameters() if param.ndim == 2]
      embedding_params = list(self.token_emb.parameters())
      lm_head_params = list(self.output.parameters())
-     assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) + len(lm_head_params))
+     rms_norm_params = [
+         mod.weight for mod in self.modules() if isinstance(mod, nn.RMSNorm) and mod.elementwise_affine
+     ]
+     assert len(list(self.parameters())) == (
+         len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(rms_norm_params)
+     )

      return {
          "matrix": matrix_params,
          "embedding": embedding_params,
          "lm_head": lm_head_params,
+         "rms_norm": rms_norm_params,
      }

  def forward(
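A rough, self-contained illustration of the grouping logic above (the toy module and hyperparameters are made up): only 2-D parameters count as `matrix` parameters, while 1-D tensors such as RMSNorm scales land in a separate group and can receive different optimizer settings.

```python
import torch
from torch import nn

# Toy module: Linear contributes a 2-D weight and a 1-D bias, RMSNorm a 1-D scale.
block = nn.Sequential(nn.Linear(8, 8), nn.RMSNorm(8, elementwise_affine=True))

matrix_params = [p for p in block.parameters() if p.ndim == 2]
other_params = [p for p in block.parameters() if p.ndim != 2]
print(len(matrix_params), len(other_params))  # 1 2

# Param groups can then get different settings (the values here are illustrative).
opt = torch.optim.AdamW([
    {"params": matrix_params, "weight_decay": 0.1},
    {"params": other_params, "weight_decay": 0.0},
])
```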
sarasa/models/nanochat_gpt.py CHANGED
@@ -8,7 +8,16 @@ from torch.nn import functional as F

  from sarasa.models import BaseModel, ModelConfig
  from sarasa.models.attention import CausalSelfAttention
- from sarasa.models.utils import RMSNorm, RoPE
+ from sarasa.models.utils import RoPE
+
+
+ class RMSNorm(torch.nn.RMSNorm):
+     # RMSNorm without affine parameters
+     def __init__(
+         self,
+         normalized_shape: int,
+     ):
+         super().__init__(normalized_shape, eps=None, elementwise_affine=False)


  class MLP(nn.Module):
sarasa/models/utils.py CHANGED
@@ -1,15 +1,6 @@
  import torch


- class RMSNorm(torch.nn.RMSNorm):
-     # RMSNorm without affine parameters
-     def __init__(
-         self,
-         normalized_shape: int,
-     ):
-         super().__init__(normalized_shape, eps=None, elementwise_affine=False)
-
-
  class RoPE:
      @staticmethod
      def precompute(
sarasa/trainer.py → sarasa/train.py RENAMED
@@ -7,6 +7,7 @@ import torch
  import torch.distributed as dist
  from loguru import logger
  from torch.distributed.elastic.multiprocessing.errors import record
+ from torch.nn import functional as F

  from sarasa.activation_checkpoint import apply_op_sac
  from sarasa.checkpoint import Checkpointer
@@ -48,9 +49,6 @@ class Trainer:
  vocab_size = len(self.tokenizer)
  self.config.model.vocab_size = vocab_size

- # todo: support other loss functions
- self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")
-
  # setup model, optimizer, lr scheduler
  with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
      self.model = self.config.model.create()
@@ -68,9 +66,9 @@ class Trainer:
  if config.train.compile:
      logger.info("Compiling the model")
      for block in self.model.blocks:
-         block.compile(fullgraph=True)
+         block.compile(fullgraph=True, dynamic=False)
      self.model.compile(dynamic=False)
-     self.loss_fn.compile()
+     self.loss_fn = torch.compile(self.loss_fn, fullgraph=True, dynamic=False)

  if world_size() > 1:
      apply_distributed(
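For context, a minimal sketch of what the two compile flags mean when applied to an arbitrary function (unrelated to sarasa's model code): `fullgraph=True` requires the whole callable to be captured in one graph instead of silently falling back, and `dynamic=False` specializes on the shapes actually seen.

```python
import torch


def sum_of_squares(x: torch.Tensor) -> torch.Tensor:
    return (x * x).sum()


# Same flags the trainer now passes to block.compile / torch.compile.
compiled = torch.compile(sum_of_squares, fullgraph=True, dynamic=False)
print(compiled(torch.randn(4, 4)))
```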
@@ -119,14 +117,6 @@ class Trainer:
          f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
      )

- def __del__(self) -> None:
-     # cleanup distributed
-     if world_size() > 1:
-         try:
-             dist.destroy_process_group()
-         except Exception as e:
-             logger.warning(f"Failed to destroy process group: {e}")
-
  @record
  def train(self):
      try:
@@ -197,7 +187,7 @@ class Trainer:

  with self.amp_context:
      pred = self.model(**input_dict)
-     loss = self.loss_fn(pred.flatten(0, 1), target.flatten(0, 1)) / valid_tokens
+     loss = self.loss_fn(pred, target) / valid_tokens

  del pred
  loss.backward()
@@ -241,6 +231,18 @@ class Trainer:
          },
      )

+ def loss_fn(
+     self,
+     pred: torch.Tensor,
+     target: torch.Tensor,
+ ) -> torch.Tensor:
+     return F.cross_entropy(
+         pred.flatten(0, 1).float(),
+         target.flatten(0, 1),
+         ignore_index=IGNORE_INDEX,
+         reduction="sum",
+     )
+
  def evaluate(self):
      raise NotImplementedError

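A self-contained sketch of what the new `loss_fn` computes (the `IGNORE_INDEX` value below is an assumption; the diff only shows that the constant exists): a token-level cross-entropy summed over non-ignored positions, which the training step then divides by `valid_tokens` to get a mean over real tokens only.

```python
import torch
from torch.nn import functional as F

IGNORE_INDEX = -100  # assumed value; PyTorch's own default ignore_index is -100

pred = torch.randn(2, 4, 10)            # (batch, seq, vocab) logits
target = torch.randint(0, 10, (2, 4))   # (batch, seq) token ids
target[0, -1] = IGNORE_INDEX            # padding contributes nothing to the sum

loss_sum = F.cross_entropy(
    pred.flatten(0, 1).float(),  # (batch * seq, vocab)
    target.flatten(0, 1),        # (batch * seq,)
    ignore_index=IGNORE_INDEX,
    reduction="sum",
)
valid_tokens = (target != IGNORE_INDEX).sum()
print(loss_sum / valid_tokens)   # mirrors `loss = self.loss_fn(pred, target) / valid_tokens`
```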
@@ -256,3 +258,13 @@ class Trainer:

      if self.metrics_processor is not None:
          self.metrics_processor.close()
+
+     # cleanup distributed
+     if world_size() > 1:
+         try:
+             dist.destroy_process_group()
+         except Exception as e:
+             logger.warning(f"Failed to destroy process group: {e}")
+
+ def __del__(self):
+     self.close()
sarasa/utils.py CHANGED
@@ -3,6 +3,7 @@ import gc
  import os
  import sys
  import time
+ import typing
  from datetime import timedelta
  from functools import cache

@@ -11,8 +12,11 @@ from loguru import logger
  from torch import distributed as dist
  from torch import nn

+ if typing.TYPE_CHECKING:
+     from sarasa.config import Config, Distributed

- def setup_logger(config) -> None:
+
+ def setup_logger(config: Config) -> None:
      logger.remove()
      if config.debug:
          logger_format = f"<blue>RANK={rank()}</blue> | " + (
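A small standalone sketch of the `typing.TYPE_CHECKING` pattern used here (module names below are placeholders, not sarasa's): the guarded import only runs for type checkers, so the annotation has to be postponed, e.g. via `from __future__ import annotations`, to avoid a `NameError` at runtime.

```python
from __future__ import annotations  # postpone annotation evaluation

import typing

if typing.TYPE_CHECKING:
    # Only imported by static type checkers, never executed at runtime.
    from collections.abc import Mapping


def describe(d: Mapping[str, int]) -> None:
    print(len(d))


describe({"a": 1})  # runs fine: the annotation is never evaluated at runtime
```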
@@ -128,7 +132,7 @@ def update_timeout(


  def apply_distributed(
-     config,
+     config: Distributed,
      model: nn.Module,
      device: torch.device,
      compile: bool,
sarasa-0.0.3.dist-info/METADATA → sarasa-0.0.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sarasa
- Version: 0.0.3
+ Version: 0.0.4
  Summary: Add your description here
  License-File: LICENSE
  Requires-Python: >=3.13
@@ -46,6 +46,8 @@ uv add sarasa[cpu|cu128|cu130]
  - Async distributed checkpoint saving

  - [ ] Checkpoint loading
+ - [ ] FP8 training
+ - [ ] Profiling

  ## Usage

@@ -100,18 +102,22 @@ if __name__ == "__main__":
  trainer.train()
  ```

+ Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
  From the command line, you can specify which custom optimizer to use:

  ```bash
  python script.py optim:custom_optim --optim.lr 0.001 ...
  ```

+ (As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
  ### Config File Example

- It's very simple. IDE autocompletion will help you.
+ It's very simple.
+ IDE autocompletion will help you.

  ```python
- from sarasa.config import Config, Data, LRScheduler, Model, Train, LRScheduler
+ from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
  from custom_optim import CustomOptim

  # only one Config instance should be defined in each config file
@@ -135,4 +141,4 @@ config = Config.create(

  ## Acknowledgements

- This project is heavily inspired by and borrows code from `torchtitan`.
+ This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
@@ -0,0 +1,21 @@
+ sarasa/__init__.py,sha256=cCbISDgbexMLgX9ff5zGZXcyXym3qN6tbzbH0rh2a1k,408
+ sarasa/activation_checkpoint.py,sha256=iGib3e2GFxBLOtgcPLQZnzw0Ru6Gd_yFqWZmUw0Cfa4,3056
+ sarasa/checkpoint.py,sha256=EjntYpZDGe5fWJhlmyeHcgLA0vjJUR7gTpziT0xWcKw,3849
+ sarasa/config.py,sha256=7MNVs5GVCZ2ezpwGIGiyDCIqyRTi7vqnp2D6fWqHv5s,9299
+ sarasa/metrics.py,sha256=OzTuK3Oed-I_2FC7rrE9FYi3NgTdKsDsVkWlGgJGh0M,10636
+ sarasa/train.py,sha256=cRWUfNGGTNVUMZq77ltRoQaw5uZdn7N8zeWLXiP2F6M,9892
+ sarasa/utils.py,sha256=3bSTXjAD5hRXIQxPmEbu7Bt709J-ha0CuHU_6Ieqvjc,4542
+ sarasa/data/__init__.py,sha256=I0JOb9QrHEj9zXUX8kLir6ONAyiozeagzApig0WcSt8,1150
+ sarasa/data/hf_datasets.py,sha256=DUlCpBOcDtZNEGrx4AtZTPW5IMtxIXMX_pKfnQEqzEg,3966
+ sarasa/data/tokenizer.py,sha256=JhUOl9USJRM-DVPY02ouiaNUhAu1w2LLGquMnAyyA68,1752
+ sarasa/models/__init__.py,sha256=1jxzxW-lg1Am38CHQhC9K1VD3U2J0Ca70kGCifhxlRA,3396
+ sarasa/models/attention.py,sha256=9mI7j_k98OJlaUgRzWNvFHeVqBs07YJpMA9oefYJvaE,3109
+ sarasa/models/llama3.py,sha256=2WACwUiS5y8rEmO7EcbXvPR3JtxhL6Jz9qRB28qzPvM,5018
+ sarasa/models/nanochat_gpt.py,sha256=zKLgwYOHjohZKQaGwrRp-7uEobZhHKoWT1lvJiYJzZA,8936
+ sarasa/models/utils.py,sha256=4aJD8HrJ1JPsR9615UmIFiDy6bwnjRffKIlGoYwolUg,887
+ sarasa/optimizers/__init__.py,sha256=TH7CV-dexzVIm_NJKpo3VxnzwLvUjPyckD_oXMT48xo,1844
+ sarasa/optimizers/utils.py,sha256=yI1_yHllJFyGbFW8jdMbvLfa5k7zUXjdgkvbr64mFOI,705
+ sarasa-0.0.4.dist-info/METADATA,sha256=1D5-WB0KHgWhxDu8UcYp_wHHOCTvd7jXghzNq5CAvvo,3805
+ sarasa-0.0.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ sarasa-0.0.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ sarasa-0.0.4.dist-info/RECORD,,
@@ -1,21 +0,0 @@
- sarasa/__init__.py,sha256=PrMUKyqcTKgKR3R3VMTzfuQ9EwMNVyw7qRl84Dz3d28,77
- sarasa/activation_checkpoint.py,sha256=iGib3e2GFxBLOtgcPLQZnzw0Ru6Gd_yFqWZmUw0Cfa4,3056
- sarasa/checkpoint.py,sha256=nZNo-qv3hvtzZuN0xw4SsCP4QmIM7E5nqGBxGfGYZo0,3616
- sarasa/config.py,sha256=7MNVs5GVCZ2ezpwGIGiyDCIqyRTi7vqnp2D6fWqHv5s,9299
- sarasa/metrics.py,sha256=OzTuK3Oed-I_2FC7rrE9FYi3NgTdKsDsVkWlGgJGh0M,10636
- sarasa/trainer.py,sha256=aRqa4ycbMeyLROP_hcehUpYif8jpGqCdNBHkt-TOTaw,9645
- sarasa/utils.py,sha256=Lee6wryHbob3PGpvR71gvpGjYNn-YsUaqoXvx6lx_tU,4431
- sarasa/data/__init__.py,sha256=I0JOb9QrHEj9zXUX8kLir6ONAyiozeagzApig0WcSt8,1150
- sarasa/data/hf_datasets.py,sha256=DUlCpBOcDtZNEGrx4AtZTPW5IMtxIXMX_pKfnQEqzEg,3966
- sarasa/data/tokenizer.py,sha256=JhUOl9USJRM-DVPY02ouiaNUhAu1w2LLGquMnAyyA68,1752
- sarasa/models/__init__.py,sha256=w9p4lZ0oEH2kRMxPh88Ogphwqx_o_Ik8Upfv9SrW7hA,3223
- sarasa/models/attention.py,sha256=rWm6NurkS5wdnzP_LPonCMJt_gQulySQDPFpVZtpGWU,3006
- sarasa/models/llama3.py,sha256=jGrrC2AQJvdyo_YvbGC4vmy23k-19Itsj_fSHUY2QTc,4509
- sarasa/models/nanochat_gpt.py,sha256=cpyoXwlWhqeUtgYSQ673AW6lh9BQ-64_9G9XURfZ1MY,8721
- sarasa/models/utils.py,sha256=_0F8yFVB2ZVClr8YppBhPOWtpNF0dSkmaElCYCkO_co,1111
- sarasa/optimizers/__init__.py,sha256=TH7CV-dexzVIm_NJKpo3VxnzwLvUjPyckD_oXMT48xo,1844
- sarasa/optimizers/utils.py,sha256=yI1_yHllJFyGbFW8jdMbvLfa5k7zUXjdgkvbr64mFOI,705
- sarasa-0.0.3.dist-info/METADATA,sha256=WXzEK22tSsKV1ZyWhnfsFTg5Ze9zYWJpe8MYjUH_V6M,3452
- sarasa-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- sarasa-0.0.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- sarasa-0.0.3.dist-info/RECORD,,