sarasa 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sarasa/__init__.py +10 -1
- sarasa/checkpoint.py +10 -0
- sarasa/models/__init__.py +2 -0
- sarasa/models/attention.py +6 -2
- sarasa/models/llama3.py +18 -7
- sarasa/models/nanochat_gpt.py +10 -1
- sarasa/models/utils.py +0 -9
- sarasa/{trainer.py → train.py} +26 -14
- sarasa/utils.py +6 -2
- {sarasa-0.0.3.dist-info → sarasa-0.0.4.dist-info}/METADATA +10 -4
- sarasa-0.0.4.dist-info/RECORD +21 -0
- sarasa-0.0.3.dist-info/RECORD +0 -21
- {sarasa-0.0.3.dist-info → sarasa-0.0.4.dist-info}/WHEEL +0 -0
- {sarasa-0.0.3.dist-info → sarasa-0.0.4.dist-info}/licenses/LICENSE +0 -0
sarasa/__init__.py
CHANGED
@@ -1,2 +1,11 @@
+from .config import DDP as DDP
+from .config import FSDP as FSDP
+from .config import AdamW as AdamW
+from .config import Checkpoint as Checkpoint
 from .config import Config as Config
-from .
+from .config import Data as Data
+from .config import LRScheduler as LRScheduler
+from .config import Metrics as Metrics
+from .config import Model as Model
+from .config import Train as Train
+from .train import Trainer as Trainer
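With 0.0.4 the package root re-exports every config class plus the renamed `Trainer`, so downstream code can import everything from `sarasa` directly instead of reaching into `sarasa.config` or the old `sarasa.trainer` module:

```python
# All of these re-exports are introduced by this release.
from sarasa import AdamW, Checkpoint, Config, DDP, Data, FSDP, LRScheduler, Metrics, Model, Train, Trainer
```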
sarasa/checkpoint.py
CHANGED
@@ -110,3 +110,13 @@ class Checkpointer:
     def close(self) -> None:
         if self.stager is not None:
             self.stager.close()
+
+        if self.save_future is not None:
+            self.save_future.result()
+
+        if self.pg is not None:
+            dist.destroy_process_group(self.pg)
+            self.pg = None
+
+    def __del__(self):
+        self.close()
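`Checkpointer.close()` now drains the pending async save (`save_future.result()`) and destroys the checkpoint-specific process group before clearing it, and `__del__` falls back to `close()` so the same path runs even without an explicit call. A self-contained sketch of that drain-then-teardown pattern, using a thread pool as a stand-in for the async save machinery (none of the names below are Checkpointer's real API):

```python
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncSaver:
    """Illustrative only: mirrors the close()/__del__ pattern added to Checkpointer."""

    def __init__(self) -> None:
        self.pool: ThreadPoolExecutor | None = ThreadPoolExecutor(max_workers=1)
        self.save_future: Future | None = None

    def save_async(self, state: dict) -> None:
        assert self.pool is not None
        self.save_future = self.pool.submit(dict, state)  # stand-in for the real async write

    def close(self) -> None:
        if self.save_future is not None:
            self.save_future.result()  # block until the in-flight save finishes
            self.save_future = None
        if self.pool is not None:
            self.pool.shutdown()       # analogous to destroying the dedicated process group
            self.pool = None           # cleared so a later __del__ call is a no-op

    def __del__(self) -> None:
        self.close()
```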
sarasa/models/__init__.py
CHANGED
@@ -21,6 +21,8 @@ class ModelConfig:
     vocab_size: int | None = None  # set later based on tokenizer
     seq_len: int | None = None  # set later based on data config
     qk_norm: bool = False  # whether to use RMSNorm on q/k
+    rms_eps: float | None = None  # epsilon for RMSNorm, default to library default if None
+    rms_learnable: bool = False  # whether RMSNorm has learnable scale parameter
 
     def __post_init__(self):
         # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
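Both new fields default to the behaviour 0.0.3 hard-coded in its `RMSNorm` subclass: in PyTorch, `nn.RMSNorm(..., eps=None)` falls back to `torch.finfo(dtype).eps` at call time, and `elementwise_affine` (driven here by `rms_learnable`) controls whether the norm owns a learnable scale vector. A quick check of what the flag changes:

```python
from torch import nn

frozen = nn.RMSNorm(512, eps=None, elementwise_affine=False)     # rms_learnable=False: no parameters
learnable = nn.RMSNorm(512, eps=None, elementwise_affine=True)   # rms_learnable=True: one scale per feature
print(sum(p.numel() for p in frozen.parameters()),               # 0
      sum(p.numel() for p in learnable.parameters()))            # 512
```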
sarasa/models/attention.py
CHANGED
@@ -3,7 +3,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from sarasa.models import ModelConfig
-from sarasa.models.utils import
+from sarasa.models.utils import RoPE
 
 
 class SDPAttention(nn.Module):
@@ -57,7 +57,11 @@ class CausalSelfAttention(nn.Module):
         self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
         self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
-        self.qk_norm =
+        self.qk_norm = (
+            nn.RMSNorm(self.head_dim, eps=config.rms_eps, elementwise_affine=config.rms_learnable)
+            if config.qk_norm
+            else nn.Identity()
+        )
 
         # todo: support varlen etc and kv caching
         self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)
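Because `nn.RMSNorm(self.head_dim)` normalizes over the trailing dimension only, the same module applies to the projected q and k tensors regardless of batch size, head count, or sequence length, and `nn.Identity()` keeps the forward path uniform when `qk_norm` is disabled. A small shape check (sizes are illustrative):

```python
import torch
from torch import nn

head_dim = 64
qk_norm = nn.RMSNorm(head_dim, eps=None, elementwise_affine=False)

q = torch.randn(2, 8, 128, head_dim)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 4, 128, head_dim)  # fewer kv heads (GQA), same head_dim
print(qk_norm(q).shape, qk_norm(k).shape)  # shapes unchanged; each head_dim vector is normalized
```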
sarasa/models/llama3.py
CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F
 
 from sarasa.models import BaseModel, ModelConfig
 from sarasa.models.attention import CausalSelfAttention
-from sarasa.models.utils import
+from sarasa.models.utils import RoPE
 
 
 class MLP(nn.Module):
@@ -41,15 +41,16 @@ class Block(nn.Module):
         self.layer_idx = layer_idx
         self.attention = CausalSelfAttention(config)
         self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
-        self.
+        self.attn_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
+        self.mlp_norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
 
     def forward(
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
     ) -> torch.Tensor:
-        x = x + self.attention(self.
-        x = x + self.mlp(self.
+        x = x + self.attention(self.attn_norm(x), cos_sin)
+        x = x + self.mlp(self.mlp_norm(x))
         return x
 
 
@@ -71,7 +72,7 @@ class Llama3(BaseModel):
         self.blocks = nn.ModuleList([
             Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
         ])
-        self.norm = RMSNorm(config.hidden_dim)
+        self.norm = nn.RMSNorm(config.hidden_dim, eps=config.rms_eps)
         self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
 
     @torch.no_grad()
@@ -101,16 +102,26 @@ class Llama3(BaseModel):
             b=cutoff_factor * final_out_std,
         )
 
+        for mod in self.modules():
+            if isinstance(mod, nn.RMSNorm):
+                mod.reset_parameters()
+
     def param_groups(self) -> dict[str, list[nn.Parameter]]:
-        matrix_params =
+        matrix_params = [param for param in self.blocks.parameters() if param.ndim == 2]
        embedding_params = list(self.token_emb.parameters())
         lm_head_params = list(self.output.parameters())
-
+        rms_norm_params = [
+            mod.weight for mod in self.modules() if isinstance(mod, nn.RMSNorm) and mod.elementwise_affine
+        ]
+        assert len(list(self.parameters())) == (
+            len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(rms_norm_params)
+        )
 
         return {
             "matrix": matrix_params,
             "embedding": embedding_params,
             "lm_head": lm_head_params,
+            "rms_norm": rms_norm_params,
         }
 
     def forward(
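`param_groups()` now exposes a fourth group for learnable RMSNorm scales, and the new assert guarantees that the four groups together cover every model parameter exactly once. A hedged sketch of how a caller might turn that dict into optimizer param groups (the hyperparameters below are illustrative, not sarasa's defaults):

```python
import torch
from torch import nn


def build_optimizer(groups: dict[str, list[nn.Parameter]]) -> torch.optim.Optimizer:
    # groups comes from Llama3.param_groups(): "matrix", "embedding", "lm_head", "rms_norm".
    return torch.optim.AdamW([
        {"params": groups["matrix"], "lr": 3e-4, "weight_decay": 0.1},
        {"params": groups["embedding"], "lr": 3e-4, "weight_decay": 0.0},
        {"params": groups["lm_head"], "lr": 3e-4, "weight_decay": 0.0},
        {"params": groups["rms_norm"], "lr": 3e-4, "weight_decay": 0.0},  # norm scales usually skip decay
    ])
```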
sarasa/models/nanochat_gpt.py
CHANGED
@@ -8,7 +8,16 @@ from torch.nn import functional as F
 
 from sarasa.models import BaseModel, ModelConfig
 from sarasa.models.attention import CausalSelfAttention
-from sarasa.models.utils import
+from sarasa.models.utils import RoPE
+
+
+class RMSNorm(torch.nn.RMSNorm):
+    # RMSNorm without affine parameters
+    def __init__(
+        self,
+        normalized_shape: int,
+    ):
+        super().__init__(normalized_shape, eps=None, elementwise_affine=False)
 
 
 class MLP(nn.Module):
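This is the same non-affine `RMSNorm` subclass that 0.0.3 kept in `sarasa/models/utils.py` (removed in the next file); it moves into the one model that still wants the fixed behaviour, while `attention.py` and `llama3.py` switch to plain `nn.RMSNorm` driven by the new config fields. The subclass is just a constructor with two arguments pinned:

```python
import torch
from torch import nn

# Equivalent construction: dtype-dependent default epsilon, no learnable parameters.
norm = nn.RMSNorm(768, eps=None, elementwise_affine=False)
x = torch.randn(4, 768)
assert norm(x).shape == x.shape and list(norm.parameters()) == []
```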
sarasa/models/utils.py
CHANGED
@@ -1,15 +1,6 @@
 import torch
 
 
-class RMSNorm(torch.nn.RMSNorm):
-    # RMSNorm without affine parameters
-    def __init__(
-        self,
-        normalized_shape: int,
-    ):
-        super().__init__(normalized_shape, eps=None, elementwise_affine=False)
-
-
 class RoPE:
     @staticmethod
     def precompute(
sarasa/{trainer.py → train.py}
RENAMED
@@ -7,6 +7,7 @@ import torch
 import torch.distributed as dist
 from loguru import logger
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.nn import functional as F
 
 from sarasa.activation_checkpoint import apply_op_sac
 from sarasa.checkpoint import Checkpointer
@@ -48,9 +49,6 @@ class Trainer:
         vocab_size = len(self.tokenizer)
         self.config.model.vocab_size = vocab_size
 
-        # todo: support other loss functions
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")
-
         # setup model, optimizer, lr scheduler
         with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
             self.model = self.config.model.create()
@@ -68,9 +66,9 @@ class Trainer:
         if config.train.compile:
             logger.info("Compiling the model")
             for block in self.model.blocks:
-                block.compile(fullgraph=True)
+                block.compile(fullgraph=True, dynamic=False)
             self.model.compile(dynamic=False)
-            self.loss_fn.compile()
+            self.loss_fn = torch.compile(self.loss_fn, fullgraph=True, dynamic=False)
 
         if world_size() > 1:
             apply_distributed(
@@ -119,14 +117,6 @@ class Trainer:
                 f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
             )
 
-    def __del__(self) -> None:
-        # cleanup distributed
-        if world_size() > 1:
-            try:
-                dist.destroy_process_group()
-            except Exception as e:
-                logger.warning(f"Failed to destroy process group: {e}")
-
     @record
     def train(self):
         try:
@@ -197,7 +187,7 @@ class Trainer:
 
         with self.amp_context:
             pred = self.model(**input_dict)
-            loss = self.loss_fn(pred
+            loss = self.loss_fn(pred, target) / valid_tokens
 
         del pred
         loss.backward()
@@ -241,6 +231,18 @@ class Trainer:
             },
         )
 
+    def loss_fn(
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        return F.cross_entropy(
+            pred.flatten(0, 1).float(),
+            target.flatten(0, 1),
+            ignore_index=IGNORE_INDEX,
+            reduction="sum",
+        )
+
     def evaluate(self):
         raise NotImplementedError
 
@@ -256,3 +258,13 @@ class Trainer:
 
         if self.metrics_processor is not None:
             self.metrics_processor.close()
+
+        # cleanup distributed
+        if world_size() > 1:
+            try:
+                dist.destroy_process_group()
+            except Exception as e:
+                logger.warning(f"Failed to destroy process group: {e}")
+
+    def __del__(self):
+        self.close()
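The loss is now computed by a plain `loss_fn` method (so it can be wrapped with `torch.compile(..., fullgraph=True)`) instead of an `nn.CrossEntropyLoss` module: a sum-reduced cross entropy over the flattened logits, ignoring padded positions, which the training step then divides by `valid_tokens`. A small self-contained check of that computation (`IGNORE_INDEX = -100` is an assumption for illustration; the real constant lives in sarasa's code):

```python
import torch
from torch.nn import functional as F

IGNORE_INDEX = -100  # assumed value, for illustration only

pred = torch.randn(2, 4, 11)                     # (batch, seq_len, vocab)
target = torch.tensor([[1, 2, IGNORE_INDEX, 3],
                       [4, IGNORE_INDEX, IGNORE_INDEX, 5]])

loss_sum = F.cross_entropy(
    pred.flatten(0, 1).float(),                  # (batch * seq_len, vocab)
    target.flatten(0, 1),
    ignore_index=IGNORE_INDEX,
    reduction="sum",
)
valid_tokens = (target != IGNORE_INDEX).sum()    # 5 non-ignored positions here
loss = loss_sum / valid_tokens                   # mean over non-ignored tokens only
print(loss)
```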
sarasa/utils.py
CHANGED
@@ -3,6 +3,7 @@ import gc
 import os
 import sys
 import time
+import typing
 from datetime import timedelta
 from functools import cache
 
@@ -11,8 +12,11 @@ from loguru import logger
 from torch import distributed as dist
 from torch import nn
 
+if typing.TYPE_CHECKING:
+    from sarasa.config import Config, Distributed
 
-
+
+def setup_logger(config: Config) -> None:
     logger.remove()
     if config.debug:
         logger_format = f"<blue>RANK={rank()}</blue> | " + (
@@ -128,7 +132,7 @@ def update_timeout(
 
 
 def apply_distributed(
-    config,
+    config: Distributed,
     model: nn.Module,
     device: torch.device,
     compile: bool,
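`Config` and `Distributed` are imported only under `typing.TYPE_CHECKING`, so the annotations in `sarasa/utils.py` get real types for static checkers without creating a runtime import cycle with `sarasa.config`. The usual caveat applies: such names are undefined at runtime, so they must only appear in annotation positions, typically together with string annotations or `from __future__ import annotations` (which may already be present in lines not shown in this hunk). A minimal sketch of the pattern, as a hypothetical standalone module:

```python
from __future__ import annotations  # annotations stay unevaluated at runtime

import typing

if typing.TYPE_CHECKING:            # only evaluated by static type checkers
    from sarasa.config import Config


def setup_logger(config: Config) -> None:
    ...                             # the Config annotation is never resolved at runtime
```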
{sarasa-0.0.3.dist-info → sarasa-0.0.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sarasa
-Version: 0.0.3
+Version: 0.0.4
 Summary: Add your description here
 License-File: LICENSE
 Requires-Python: >=3.13
@@ -46,6 +46,8 @@ uv add sarasa[cpu|cu128|cu130]
 - Async distributed checkpoint saving
 
 - [ ] Checkpoint loading
+- [ ] FP8 training
+- [ ] Profiling
 
 ## Usage
 
@@ -100,18 +102,22 @@ if __name__ == "__main__":
     trainer.train()
 ```
 
+Thanks to [tyro](https://github.com/brentyi/tyro)'s type support, Sarasa can automatically recognize multiple custom optimizer types.
 From the command line, you can specify which custom optimizer to use:
 
 ```bash
 python script.py optim:custom_optim --optim.lr 0.001 ...
 ```
 
+(As tyro automatically converts config class names from CamelCase to snake_case, config class names are recommended not to include `Config` suffixes.)
+
 ### Config File Example
 
-It's very simple.
+It's very simple.
+IDE autocompletion will help you.
 
 ```python
-from sarasa
+from sarasa import Config, Data, LRScheduler, Model, Train, LRScheduler
 from custom_optim import CustomOptim
 
 # only one Config instance should be defined in each config file
@@ -135,4 +141,4 @@ config = Config.create(
 
 ## Acknowledgements
 
-This project is heavily inspired by and borrows code from
+This project is heavily inspired by and borrows code from [torchtitan](https://github.com/pytorch/torchtitan).
sarasa-0.0.4.dist-info/RECORD
ADDED
@@ -0,0 +1,21 @@
+sarasa/__init__.py,sha256=cCbISDgbexMLgX9ff5zGZXcyXym3qN6tbzbH0rh2a1k,408
+sarasa/activation_checkpoint.py,sha256=iGib3e2GFxBLOtgcPLQZnzw0Ru6Gd_yFqWZmUw0Cfa4,3056
+sarasa/checkpoint.py,sha256=EjntYpZDGe5fWJhlmyeHcgLA0vjJUR7gTpziT0xWcKw,3849
+sarasa/config.py,sha256=7MNVs5GVCZ2ezpwGIGiyDCIqyRTi7vqnp2D6fWqHv5s,9299
+sarasa/metrics.py,sha256=OzTuK3Oed-I_2FC7rrE9FYi3NgTdKsDsVkWlGgJGh0M,10636
+sarasa/train.py,sha256=cRWUfNGGTNVUMZq77ltRoQaw5uZdn7N8zeWLXiP2F6M,9892
+sarasa/utils.py,sha256=3bSTXjAD5hRXIQxPmEbu7Bt709J-ha0CuHU_6Ieqvjc,4542
+sarasa/data/__init__.py,sha256=I0JOb9QrHEj9zXUX8kLir6ONAyiozeagzApig0WcSt8,1150
+sarasa/data/hf_datasets.py,sha256=DUlCpBOcDtZNEGrx4AtZTPW5IMtxIXMX_pKfnQEqzEg,3966
+sarasa/data/tokenizer.py,sha256=JhUOl9USJRM-DVPY02ouiaNUhAu1w2LLGquMnAyyA68,1752
+sarasa/models/__init__.py,sha256=1jxzxW-lg1Am38CHQhC9K1VD3U2J0Ca70kGCifhxlRA,3396
+sarasa/models/attention.py,sha256=9mI7j_k98OJlaUgRzWNvFHeVqBs07YJpMA9oefYJvaE,3109
+sarasa/models/llama3.py,sha256=2WACwUiS5y8rEmO7EcbXvPR3JtxhL6Jz9qRB28qzPvM,5018
+sarasa/models/nanochat_gpt.py,sha256=zKLgwYOHjohZKQaGwrRp-7uEobZhHKoWT1lvJiYJzZA,8936
+sarasa/models/utils.py,sha256=4aJD8HrJ1JPsR9615UmIFiDy6bwnjRffKIlGoYwolUg,887
+sarasa/optimizers/__init__.py,sha256=TH7CV-dexzVIm_NJKpo3VxnzwLvUjPyckD_oXMT48xo,1844
+sarasa/optimizers/utils.py,sha256=yI1_yHllJFyGbFW8jdMbvLfa5k7zUXjdgkvbr64mFOI,705
+sarasa-0.0.4.dist-info/METADATA,sha256=1D5-WB0KHgWhxDu8UcYp_wHHOCTvd7jXghzNq5CAvvo,3805
+sarasa-0.0.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+sarasa-0.0.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sarasa-0.0.4.dist-info/RECORD,,
sarasa-0.0.3.dist-info/RECORD
DELETED
@@ -1,21 +0,0 @@
-sarasa/__init__.py,sha256=PrMUKyqcTKgKR3R3VMTzfuQ9EwMNVyw7qRl84Dz3d28,77
-sarasa/activation_checkpoint.py,sha256=iGib3e2GFxBLOtgcPLQZnzw0Ru6Gd_yFqWZmUw0Cfa4,3056
-sarasa/checkpoint.py,sha256=nZNo-qv3hvtzZuN0xw4SsCP4QmIM7E5nqGBxGfGYZo0,3616
-sarasa/config.py,sha256=7MNVs5GVCZ2ezpwGIGiyDCIqyRTi7vqnp2D6fWqHv5s,9299
-sarasa/metrics.py,sha256=OzTuK3Oed-I_2FC7rrE9FYi3NgTdKsDsVkWlGgJGh0M,10636
-sarasa/trainer.py,sha256=aRqa4ycbMeyLROP_hcehUpYif8jpGqCdNBHkt-TOTaw,9645
-sarasa/utils.py,sha256=Lee6wryHbob3PGpvR71gvpGjYNn-YsUaqoXvx6lx_tU,4431
-sarasa/data/__init__.py,sha256=I0JOb9QrHEj9zXUX8kLir6ONAyiozeagzApig0WcSt8,1150
-sarasa/data/hf_datasets.py,sha256=DUlCpBOcDtZNEGrx4AtZTPW5IMtxIXMX_pKfnQEqzEg,3966
-sarasa/data/tokenizer.py,sha256=JhUOl9USJRM-DVPY02ouiaNUhAu1w2LLGquMnAyyA68,1752
-sarasa/models/__init__.py,sha256=w9p4lZ0oEH2kRMxPh88Ogphwqx_o_Ik8Upfv9SrW7hA,3223
-sarasa/models/attention.py,sha256=rWm6NurkS5wdnzP_LPonCMJt_gQulySQDPFpVZtpGWU,3006
-sarasa/models/llama3.py,sha256=jGrrC2AQJvdyo_YvbGC4vmy23k-19Itsj_fSHUY2QTc,4509
-sarasa/models/nanochat_gpt.py,sha256=cpyoXwlWhqeUtgYSQ673AW6lh9BQ-64_9G9XURfZ1MY,8721
-sarasa/models/utils.py,sha256=_0F8yFVB2ZVClr8YppBhPOWtpNF0dSkmaElCYCkO_co,1111
-sarasa/optimizers/__init__.py,sha256=TH7CV-dexzVIm_NJKpo3VxnzwLvUjPyckD_oXMT48xo,1844
-sarasa/optimizers/utils.py,sha256=yI1_yHllJFyGbFW8jdMbvLfa5k7zUXjdgkvbr64mFOI,705
-sarasa-0.0.3.dist-info/METADATA,sha256=WXzEK22tSsKV1ZyWhnfsFTg5Ze9zYWJpe8MYjUH_V6M,3452
-sarasa-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-sarasa-0.0.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-sarasa-0.0.3.dist-info/RECORD,,
{sarasa-0.0.3.dist-info → sarasa-0.0.4.dist-info}/WHEEL
File without changes

{sarasa-0.0.3.dist-info → sarasa-0.0.4.dist-info}/licenses/LICENSE
File without changes