litgpt 0.1.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lit_gpt/__init__.py +27 -0
- lit_gpt/adapter.py +168 -0
- lit_gpt/adapter_v2.py +224 -0
- lit_gpt/args.py +81 -0
- lit_gpt/config.py +1447 -0
- lit_gpt/lora.py +737 -0
- lit_gpt/model.py +390 -0
- lit_gpt/packed_dataset.py +239 -0
- lit_gpt/rmsnorm.py +34 -0
- lit_gpt/tokenizer.py +109 -0
- lit_gpt/utils.py +379 -0
- litgpt-0.1.0.dev1.dist-info/LICENSE +201 -0
- litgpt-0.1.0.dev1.dist-info/METADATA +284 -0
- litgpt-0.1.0.dev1.dist-info/RECORD +16 -0
- litgpt-0.1.0.dev1.dist-info/WHEEL +5 -0
- litgpt-0.1.0.dev1.dist-info/top_level.txt +1 -0
lit_gpt/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from lit_gpt.model import GPT
|
|
7
|
+
from lit_gpt.config import Config
|
|
8
|
+
from lit_gpt.tokenizer import Tokenizer
|
|
9
|
+
|
|
10
|
+
from lightning_utilities.core.imports import RequirementCache
|
|
11
|
+
|
|
12
|
+
# Fail at import time when the installed `lightning` does not satisfy the version requirement below.
_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.2.0.dev0")
if not _LIGHTNING_AVAILABLE:
    raise ImportError(
        "Lit-GPT requires lightning nightly. Please run:\n"
        f" pip uninstall -y lightning; pip install -r requirements.txt\n{_LIGHTNING_AVAILABLE!s}"
    )

# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")
# The filter keeps a record only when its message does NOT match `pattern`
logging.getLogger("torch._dynamo.variables.torch").addFilter(
    lambda record: pattern.search(record.getMessage()) is None
)

# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True

# Names re-exported as the package's public API
__all__ = ["GPT", "Config", "Tokenizer"]
|
lit_gpt/adapter.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
|
2
|
+
|
|
3
|
+
"""Implementation of the paper:
|
|
4
|
+
|
|
5
|
+
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
|
|
6
|
+
https://arxiv.org/abs/2303.16199
|
|
7
|
+
|
|
8
|
+
Port for Lit-GPT
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
from lit_gpt.config import Config as BaseConfig
|
|
19
|
+
from lit_gpt.model import GPT as BaseModel
|
|
20
|
+
from lit_gpt.model import Block as BaseBlock
|
|
21
|
+
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class Config(BaseConfig):
    """`lit_gpt.config.Config` extended with the two adapter hyperparameters."""

    # number of rows in each adapted attention layer's `adapter_wte` embedding (the adaption prompt)
    adapter_prompt_length: int = 10
    # attention layers whose block index is below this value get no adapter parameters
    adapter_start_layer: int = 2
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GPT(BaseModel):
    """Adapter variant of `lit_gpt.model.GPT`.

    The implementation mirrors the base model, except that each adapter `Block` is built
    with its layer index, which the block hands to its attention layer."""

    def __init__(self, config: Config) -> None:
        # `nn.Module.__init__` is called directly (not the parent constructor); all submodules are built here
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        vocab_size = config.padded_vocab_size
        # creation order (lm_head, wte, h, ln_f) is kept as-is: it determines module registration order
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, config.n_embd),
                "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
                "ln_f": config.norm_class(config.n_embd, eps=config.norm_eps),
            }
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the model on token ids `idx`.

        Returns the `lm_head` output for the whole sequence, or a list of per-chunk outputs
        (split along dim 1) when `lm_head_chunk_size > 0`.

        Raises:
            ValueError: if the sequence is longer than `max_seq_length`.
            TypeError: if `input_pos` is given but no mask cache has been set.
        """
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is None:
            cos, sin, mask = self.cos[:T], self.sin[:T], None
        else:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)

        hidden = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for layer in self.transformer.h:
            hidden = layer(hidden, cos, sin, mask, input_pos)
        hidden = self.transformer.ln_f(hidden)

        if lm_head_chunk_size <= 0:
            return self.lm_head(hidden)  # (b, t, vocab_size)
        # chunk the lm head logits to reduce the peak memory used by autograd
        return [self.lm_head(chunk) for chunk in hidden.split(lm_head_chunk_size, dim=1)]

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build the model from the adapter `Config` registered under `name`."""
        adapter_config = Config.from_name(name, **kwargs)
        return cls(adapter_config)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        # adapter attention layers additionally reset their own parameters
        if isinstance(module, CausalSelfAttention):
            module.reset_parameters()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class Block(BaseBlock):
    """Same as `lit_gpt.model.Block`, except that the attention layer is the adapter
    `CausalSelfAttention`, which receives this block's index."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        norm_cls = config.norm_class
        self.norm_1 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        # `norm_2` is only created when the attention norm is not shared
        if not config.shared_attention_norm:
            self.norm_2 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
    over the adaption prompt.

    Adapter parameters (`adapter_wte`, `gating_factor`) and `adapter_kv_cache` only exist on
    layers with `block_idx >= config.adapter_start_layer`."""

    def __init__(self, config: Config, block_idx: int) -> None:
        super().__init__(config)
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return the parent attention output, plus the gated attention over the adapter prompt.

        Layers below `adapter_start_layer` return the parent's result unchanged.
        """
        y = super().scaled_dot_product_attention(q, k, v, mask)
        if self.block_idx < self.config.adapter_start_layer:
            return y

        aT = self.config.adapter_prompt_length
        if self.adapter_kv_cache is not None:
            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
            # are the same every call
            ak, av = self.adapter_kv_cache
        else:
            # project the adapter prompt through the layer's own qkv projection; only k and v are kept
            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
            aqkv = self.attn(prefix)
            q_per_kv = self.config.n_head // self.config.n_query_groups
            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
            aqkv = aqkv.permute(0, 2, 3, 1, 4)
            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
            if self.config.n_query_groups != 1:
                # for MHA this is a no-op
                ak = ak.repeat_interleave(q_per_kv, dim=2)
                av = av.repeat_interleave(q_per_kv, dim=2)
            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
            # NOTE(review): the cache is filled on the first call regardless of train/eval mode - confirm callers
            # reset it when the adapter weights change
            self.adapter_kv_cache = (ak, av)

        T = q.size(2)
        # every query position may attend to every adapter prompt position
        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
        ay = super().scaled_dot_product_attention(q, ak, av, amask)
        return y + self.gating_factor * ay

    def reset_parameters(self) -> None:
        """Zero the adapter gate, if this layer has one."""
        # `GPT._init_weights` calls this on every attention layer, but layers below
        # `adapter_start_layer` never create `gating_factor`; skip them instead of raising AttributeError
        if hasattr(self, "gating_factor"):
            torch.nn.init.zeros_(self.gating_factor)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with older checkpoints."""
        # a gate whose dim 1 has `n_head` entries is permuted to put the heads on dim 2, matching `__init__`
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def mark_only_adapter_as_trainable(model: GPT) -> None:
    """Make adapter weights trainable and set `requires_grad=False` for all non-adapter weights."""
    for param_name, weight in model.named_parameters():
        # `adapter_filter` decides per parameter name whether it belongs to the adapter
        is_adapter_weight = adapter_filter(param_name, weight)
        weight.requires_grad = is_adapter_weight
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def adapter_filter(key: str, value: Any) -> bool:
    """Return True when parameter name `key` belongs to the adapter; `value` is not inspected."""
    return any(marker in key for marker in ("adapter_wte", "gating_factor"))
|
lit_gpt/adapter_v2.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
|
2
|
+
|
|
3
|
+
"""Implementation of the paper:
|
|
4
|
+
|
|
5
|
+
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
|
|
6
|
+
https://arxiv.org/abs/2304.15010
|
|
7
|
+
|
|
8
|
+
Port for Lit-GPT
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Dict, Optional, Tuple, Type
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
import lit_gpt
|
|
19
|
+
from lit_gpt.adapter import GPT as BaseModel
|
|
20
|
+
from lit_gpt.adapter import Block as BaseBlock
|
|
21
|
+
from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
|
|
22
|
+
from lit_gpt.adapter import Config as BaseConfig
|
|
23
|
+
from lit_gpt.model import KVCache
|
|
24
|
+
from lit_gpt.utils import map_old_state_dict_weights
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class Config(BaseConfig):
    """Adapter config whose `mlp_class` is resolved by name from `lit_gpt.adapter_v2`."""

    @property
    def mlp_class(self) -> Type:
        # look up the class named by `_mlp_class` among this module's definitions
        adapter_v2_module = lit_gpt.adapter_v2
        return getattr(adapter_v2_module, self._mlp_class)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def adapter_filter(key: str, value: Any) -> bool:
    """Return True when parameter name `key` contains any Adapter V2 marker; `value` is not inspected."""
    adapter_substrings = (
        # regular adapter v1 parameters
        "adapter_wte",
        "gating_factor",
        # adapter v2: new bias and scale used in Linear
        "adapter_scale",
        "adapter_bias",
        # adapter v2: Norm parameters are now trainable
        "norm_1",
        "norm_2",
        "ln_f",
    )
    for marker in adapter_substrings:
        if marker in key:
            return True
    return False
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AdapterV2Linear(torch.nn.Module):
    """A `torch.nn.Linear` whose output is shifted by `adapter_bias` and then multiplied by `adapter_scale`."""

    def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
        """Build the wrapped linear layer; `kwargs` are forwarded to `torch.nn.Linear`."""
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        # one bias and one scale per output feature; both start as the identity transform and are created frozen
        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shifted = self.linear(x) + self.adapter_bias
        return self.adapter_scale * shifted

    def reset_parameters(self) -> None:
        """Restore the adapter parameters to the identity transform (scale 1, bias 0)."""
        nn.init.ones_(self.adapter_scale)
        nn.init.zeros_(self.adapter_bias)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class GPT(BaseModel):
    """`lit_gpt.adapter.GPT` assembled from the Adapter V2 `Block`, with an `AdapterV2Linear` output head."""

    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        vocab_size = config.padded_vocab_size
        # creation order (lm_head, wte, h, ln_f) is kept as-is: it determines module registration order
        self.lm_head = AdapterV2Linear(config.n_embd, vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, config.n_embd),
                "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
                "ln_f": config.norm_class(config.n_embd, eps=config.norm_eps),
            }
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build the model from the Adapter V2 `Config` registered under `name`."""
        v2_config = Config.from_name(name, **kwargs)
        return cls(v2_config)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        # adapter linears additionally reset their scale and bias
        if isinstance(module, AdapterV2Linear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # the head's weights live on the wrapped `.linear`; the helper presumably renames the old keys - confirm
        renames = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
        state_dict = map_old_state_dict_weights(state_dict, renames, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Block(BaseBlock):
    """Same layout as `lit_gpt.model.Block`, except that the attention layer is the Adapter V2
    `CausalSelfAttention`, which receives this block's index."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        norm_cls = config.norm_class
        self.norm_1 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        # `norm_2` is only created when the attention norm is not shared
        if not config.shared_attention_norm:
            self.norm_2 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        qkv_features = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=qkv_features, bias=config.bias)
        # output projection
        # if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
        self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        # adapter parameters only exist from `adapter_start_layer` onwards
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # weights live on the wrapped `.linear`; the helper presumably renames the old keys - confirm
        renames = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, renames, prefix)
        # For compatibility with older checkpoints
        gate_key = prefix + "gating_factor"
        if gate_key in state_dict and state_dict[gate_key].size(1) == self.config.n_head:
            state_dict[gate_key] = state_dict[gate_key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    """`lit_gpt.model.GptNeoxMLP` whose `fc` and `proj` layers are `AdapterV2Linear` modules."""

    def __init__(self, config: Config) -> None:
        # only `nn.Module` is initialised; the layers are built here instead of in the parent constructor
        nn.Module.__init__(self)
        hidden_features = config.intermediate_size
        self.fc = AdapterV2Linear(config.n_embd, hidden_features, bias=config.bias)
        self.proj = AdapterV2Linear(hidden_features, config.n_embd, bias=config.bias)

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # weights live on the wrapped `.linear`; the helper presumably renames the old keys - confirm
        renames = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, renames, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    """`lit_gpt.model.LLaMAMLP` whose `fc_1`, `fc_2` and `proj` layers are `AdapterV2Linear` modules."""

    def __init__(self, config: Config) -> None:
        # only `nn.Module` is initialised; the layers are built here instead of in the parent constructor
        nn.Module.__init__(self)
        hidden_features = config.intermediate_size
        self.fc_1 = AdapterV2Linear(config.n_embd, hidden_features, bias=config.bias)
        self.fc_2 = AdapterV2Linear(config.n_embd, hidden_features, bias=config.bias)
        self.proj = AdapterV2Linear(hidden_features, config.n_embd, bias=config.bias)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # weights live on the wrapped `.linear`; the helper presumably renames the old keys - confirm
        renames = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, renames, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class GemmaMLP(LLaMAMLP):
    """`LLaMAMLP` layers with a GELU-gated forward pass."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_input = self.fc_1(x)
        up = self.fc_2(x)
        # GELU of the first projection gates the second one before the output projection
        return self.proj(torch.nn.functional.gelu(gate_input) * up)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class LLaMAMoE(lit_gpt.model.LLaMAMoE):
    """`lit_gpt.model.LLaMAMoE` with an `AdapterV2Linear` gate and Adapter V2 `LLaMAMLP` experts."""

    def __init__(self, config: Config) -> None:
        # only `nn.Module` is initialised; the layers are built here instead of in the parent constructor
        nn.Module.__init__(self)
        self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False)
        # one expert MLP per configured expert
        self.experts = nn.ModuleList([LLaMAMLP(config) for _ in range(config.n_expert)])

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # the gate weight lives on the wrapped `.linear`; the helper presumably renames the old key - confirm
        renames = {"gate.weight": "gate.linear.weight"}
        state_dict = map_old_state_dict_weights(state_dict, renames, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
    """Make Adapter V2 weights trainable and set requires_grad=False for all non-adapter weights"""
    for param_name, weight in model.named_parameters():
        # `adapter_filter` decides per parameter name whether it belongs to the adapter
        is_adapter_weight = adapter_filter(param_name, weight)
        weight.requires_grad = is_adapter_weight
|
lit_gpt/args.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class TrainArgs:
    """Training related arguments"""

    save_interval: int = 1000
    """Number of optimizer steps between checkpoints"""
    log_interval: int = 1
    """Number of iterations between logging calls"""
    global_batch_size: int = 64
    """Number of samples between optimizer steps across data-parallel ranks"""
    micro_batch_size: int = 4
    """Number of samples per data-parallel rank"""
    lr_warmup_steps: int = 100
    """Number of iterations with learning rate warmup active"""
    epochs: Optional[int] = None
    """Number of epochs to run"""
    epoch_size: Optional[int] = None
    """Size of the epoch"""
    # TODO: pretrain/tinyllama is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?
    max_tokens: Optional[int] = None
    """Total number of tokens to train on"""
    max_seq_length: Optional[int] = None
    """Limits the length of samples. Off by default"""

    # Optimization args
    learning_rate: float = 1e-3
    weight_decay: float = 0.02
    beta1: float = 0.9
    beta2: float = 0.95
    max_norm: Optional[float] = None
    min_lr: float = 6e-5

    def max_iters(self, devices: int) -> int:
        """Number of iterations

        Computed as `epochs * epoch_size // devices // micro_batch_size`.

        Raises:
            TypeError: if `epochs` or `epoch_size` is unset (both default to `None`).
            AssertionError: if the result is not positive.
        """
        if self.epochs is None or self.epoch_size is None:
            # same exception type the bare `None * None` produced, but naming the fields to set
            raise TypeError(
                "`max_iters` requires both `epochs` and `epoch_size` to be set,"
                f" got epochs={self.epochs!r} and epoch_size={self.epoch_size!r}"
            )
        max_iters = self.epochs * self.epoch_size // devices // self.micro_batch_size
        assert max_iters > 0, f"max_iters must be positive, got {max_iters}"
        return max_iters

    def gradient_accumulation_iters(self, devices: int) -> int:
        """Number of iterations between gradient synchronizations

        Raises:
            AssertionError: if the per-rank batch size is smaller than `micro_batch_size`.
        """
        gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size
        assert gradient_accumulation_iters > 0, (
            f"gradient_accumulation_iters must be positive, got {gradient_accumulation_iters}"
        )
        return gradient_accumulation_iters

    def batch_size(self, devices: int) -> int:
        """Number of samples between optimizer steps per data-parallel rank

        Raises:
            AssertionError: if `global_batch_size` is smaller than `devices`.
        """
        batch_size = self.global_batch_size // devices
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"
        return batch_size
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class EvalArgs:
    """Evaluation related arguments"""

    # how often evaluation runs
    interval: int = 600
    """Number of optimizer steps between evaluation calls"""
    # `None` leaves the generation length unset
    max_new_tokens: Optional[int] = None
    """Number of tokens to generate"""
    # iteration count for evaluation
    max_iters: int = 100
    """Number of iterations"""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class IOArgs:
    """Inputs and outputs related arguments"""

    # Optional because pretrain/tinyllama hardcodes the path
    train_data_dir: Optional[Path] = Path("data/alpaca")
    """Where to read training data from"""
    # no validation data directory by default
    val_data_dir: Optional[Path] = None
    """Where to read validation data from"""
    # no checkpoint directory by default
    checkpoint_dir: Optional[Path] = None
    """Where to read weights and tokenizer data from"""
    # the only field here that is not Optional
    out_dir: Path = Path("out/adapter/alpaca")
    """Where to save artifacts"""
|