litgpt 0.1.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lit_gpt/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import re
4
+ import logging
5
+
6
+ from lit_gpt.model import GPT
7
+ from lit_gpt.config import Config
8
+ from lit_gpt.tokenizer import Tokenizer
9
+
10
+ from lightning_utilities.core.imports import RequirementCache
11
+
12
# Fail fast at import time when the installed `lightning` is older than the required nightly build.
_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.2.0.dev0")
if not _LIGHTNING_AVAILABLE:
    raise ImportError(
        "Lit-GPT requires lightning nightly. Please run:\n"
        f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
    )
18
+
19
# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")


def _drop_profiler_warnings(record: logging.LogRecord) -> bool:
    """Logging filter: reject records whose message matches `pattern`, keep all others."""
    return not pattern.search(record.getMessage())


logging.getLogger("torch._dynamo.variables.torch").addFilter(_drop_profiler_warnings)

# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
for _fsdp_logger_name in ("torch.distributed.fsdp._optim_utils", "torch.distributed.fsdp._debug_utils"):
    logging.getLogger(_fsdp_logger_name).disabled = True
del _fsdp_logger_name
26
+
27
# Public API, re-exported from the submodules imported at the top of this package
__all__ = [
    "GPT",
    "Config",
    "Tokenizer",
]
lit_gpt/adapter.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Implementation of the paper:
4
+
5
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
6
+ https://arxiv.org/abs/2303.16199
7
+
8
+ Port for Lit-GPT
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from typing_extensions import Self
17
+
18
+ from lit_gpt.config import Config as BaseConfig
19
+ from lit_gpt.model import GPT as BaseModel
20
+ from lit_gpt.model import Block as BaseBlock
21
+ from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
22
+
23
+
24
@dataclass
class Config(BaseConfig):
    """`lit_gpt.config.Config` extended with the LLaMA-Adapter hyperparameters."""

    # Number of learnable prompt embeddings (rows of `adapter_wte`) in each adapted layer
    adapter_prompt_length: int = 10
    # Index of the first block that gets an adapter prompt and gate; earlier blocks are plain attention
    adapter_start_layer: int = 2
28
+
29
+
30
class GPT(BaseModel):
    """The implementation is identical to `lit_gpt.model.GPT` with the exception that
    the `Block` saves the layer index and passes it down to the attention layer."""

    def __init__(self, config: Config) -> None:
        # Build the modules here instead of in BaseModel.__init__ so that the blocks are the
        # index-aware adapter `Block`s. Construction order (lm_head, wte, h, ln_f) is kept so that
        # seeded initialization is unchanged.
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        vocab_size, width = config.padded_vocab_size, config.n_embd
        self.lm_head = nn.Linear(width, vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, width),
                "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
                "ln_f": config.norm_class(width, eps=config.norm_eps),
            }
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Compute logits for the token ids `idx`.

        Args:
            idx: token ids; dimension 1 is the sequence length.
            input_pos: positions to process when decoding with the kv cache (`set_kv_cache()` must have
                been called); `None` runs without the kv cache.
            lm_head_chunk_size: if positive, apply the LM head chunk-by-chunk along the sequence and
                return a list of logits tensors.

        Raises:
            ValueError: if the sequence is longer than `max_seq_length`.
            TypeError: if `input_pos` is given but the mask cache has not been set up.
        """
        seq_len = idx.size(1)
        if seq_len > self.max_seq_length:
            raise ValueError(f"Cannot forward sequence of length {seq_len}, max seq length is only {self.max_seq_length}.")

        if input_pos is None:
            cos, sin = self.cos[:seq_len], self.sin[:seq_len]
            mask = None
        else:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)

        hidden = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            hidden = block(hidden, cos, sin, mask, input_pos)
        hidden = self.transformer.ln_f(hidden)
        if lm_head_chunk_size <= 0:
            return self.lm_head(hidden)  # (b, t, vocab_size)
        # chunk the lm head logits to reduce the peak memory used by autograd
        return [self.lm_head(chunk) for chunk in hidden.split(lm_head_chunk_size, dim=1)]

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a model from the adapter `Config` registered under `name`."""
        config = Config.from_name(name, **kwargs)
        return cls(config)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if not isinstance(module, CausalSelfAttention):
            return
        module.reset_parameters()
86
+
87
+
88
class Block(BaseBlock):
    """A transformer block identical to `lit_gpt.model.Block` except that its attention layer is the
    adapter-aware `CausalSelfAttention`, which receives the block index."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Deliberately bypass BaseBlock.__init__ so the plain attention layer is never allocated
        nn.Module.__init__(self)
        norm_class = config.norm_class
        self.norm_1 = norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        # a shared attention norm means there is no separate second norm module
        if not config.shared_attention_norm:
            self.norm_2 = norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)
        self.config = config
102
+
103
+
104
class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
    over the adaption prompt.

    Blocks with ``block_idx >= config.adapter_start_layer`` own a learnable prompt (``adapter_wte``)
    and a gate (``gating_factor``). They return ``y + gating_factor * ay``, where ``ay`` is the
    attention of the queries over the prompt's keys/values. Earlier blocks return the plain
    attention output.
    """

    def __init__(self, config: Config, block_idx: int) -> None:
        super().__init__(config)
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption, shape (1, 1, n_head, 1); zeros make the adapter term vanish at init
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        y = super().scaled_dot_product_attention(q, k, v, mask)
        if self.block_idx < self.config.adapter_start_layer:
            return y

        aT = self.config.adapter_prompt_length
        # The prompt keys/values depend only on the adapter weights, so while the regular kv cache is
        # active (incremental decoding) they are identical at every step and can be reused. In all other
        # cases (training, evaluation without a kv cache) recompute them: a cached copy would hold a
        # freed autograd graph and weights that an optimizer step has since changed.
        use_adapter_cache = self.kv_cache is not None
        if use_adapter_cache and self.adapter_kv_cache is not None:
            ak, av = self.adapter_kv_cache
        else:
            ak, av = self._adapter_prompt_kv()
            # also drops a cache left over from an earlier generation once the kv cache is cleared
            self.adapter_kv_cache = (ak, av) if use_adapter_cache else None

        T = q.size(2)
        # every query position may attend to every prompt position
        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
        ay = super().scaled_dot_product_attention(q, ak, av, amask)
        return y + self.gating_factor * ay

    def _adapter_prompt_kv(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project the adapter prompt through the fused qkv layer and return its keys and values."""
        aT = self.config.adapter_prompt_length
        prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
        aqkv = self.attn(prefix)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
        aqkv = aqkv.permute(0, 2, 3, 1, 4)
        _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
        if self.config.n_query_groups != 1:
            # for MHA (q_per_kv == 1) this is a no-op; with a single group the repeat is skipped,
            # presumably relying on broadcasting in the attention call
            ak = ak.repeat_interleave(q_per_kv, dim=2)
            av = av.repeat_interleave(q_per_kv, dim=2)
        ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
        av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
        return ak, av

    def reset_parameters(self) -> None:
        """Re-zero the adapter gate. Blocks before `adapter_start_layer` have no gate and are left alone."""
        if hasattr(self, "gating_factor"):
            torch.nn.init.zeros_(self.gating_factor)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with older checkpoints."""
        # older checkpoints have the head axis of the gate at dim 1; move it to dim 2
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
159
+
160
+
161
def mark_only_adapter_as_trainable(model: GPT) -> None:
    """Freeze every parameter of `model` except the adapter prompt embeddings and gates.

    Adapter parameters get `requires_grad=True`; all others get `requires_grad=False`.
    """
    for param_name, parameter in model.named_parameters():
        trainable = adapter_filter(param_name, parameter)
        parameter.requires_grad = trainable
165
+
166
+
167
def adapter_filter(key: str, value: Any) -> bool:
    """Return whether the parameter named `key` is an adapter parameter (prompt embedding or gate).

    `value` is unused; it is accepted so the function can be called with `(name, parameter)` pairs.
    """
    adapter_markers = ("adapter_wte", "gating_factor")
    return any(marker in key for marker in adapter_markers)
lit_gpt/adapter_v2.py ADDED
@@ -0,0 +1,224 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Implementation of the paper:
4
+
5
+ LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
6
+ https://arxiv.org/abs/2304.15010
7
+
8
+ Port for Lit-GPT
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any, Dict, Optional, Tuple, Type
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from typing_extensions import Self
17
+
18
+ import lit_gpt
19
+ from lit_gpt.adapter import GPT as BaseModel
20
+ from lit_gpt.adapter import Block as BaseBlock
21
+ from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
22
+ from lit_gpt.adapter import Config as BaseConfig
23
+ from lit_gpt.model import KVCache
24
+ from lit_gpt.utils import map_old_state_dict_weights
25
+
26
+
27
@dataclass
class Config(BaseConfig):
    """Adapter config whose `mlp_class` resolves to the Adapter V2 MLP classes of this module."""

    @property
    def mlp_class(self) -> Type:
        # Resolve the MLP class by name within `lit_gpt.adapter_v2`, whose MLPs use AdapterV2Linear
        adapter_v2_module = lit_gpt.adapter_v2
        return getattr(adapter_v2_module, self._mlp_class)
32
+
33
+
34
def adapter_filter(key: str, value: Any) -> bool:
    """Return whether the parameter named `key` is trainable under Adapter V2.

    The trainable set is the v1 prompt/gate parameters, the AdapterV2Linear scale and bias, and all
    norm layers. `value` is unused; it is accepted so the function can be called with
    `(name, parameter)` pairs.
    """
    trainable_markers = (
        "adapter_wte",  # regular adapter v1 parameters
        "gating_factor",
        "adapter_scale",  # adapter v2: new bias and scale used in Linear
        "adapter_bias",
        "norm_1",  # adapter v2: Norm parameters are now trainable
        "norm_2",
        "ln_f",
    )
    for marker in trainable_markers:
        if marker in key:
            return True
    return False
48
+
49
+
50
class AdapterV2Linear(torch.nn.Module):
    """A `torch.nn.Linear` wrapped with a learnable per-output-feature scale and bias (LLaMA-Adapter V2).

    Computes ``adapter_scale * (linear(x) + adapter_bias)``. The adapter parameters start as the
    identity (scale ones, bias zeros) and are frozen (``requires_grad=False``) until explicitly
    marked trainable.
    """

    def __init__(self, in_features: int, out_features: int, **kwargs: Any) -> None:
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            **kwargs: forwarded to `torch.nn.Linear` (e.g. ``bias``, ``device``, ``dtype``). ``device`` and
                ``dtype`` are also applied to the adapter parameters so they match the wrapped weight.
        """
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        # Previously the adapter parameters ignored device/dtype and always used the global defaults
        factory_kwargs = {"device": kwargs.get("device"), "dtype": kwargs.get("dtype")}
        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features, **factory_kwargs), requires_grad=False)
        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features, **factory_kwargs), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.adapter_scale * (self.linear(x) + self.adapter_bias)

    def reset_parameters(self) -> None:
        """Restore the adapter to the identity; the wrapped `linear` is not touched."""
        nn.init.zeros_(self.adapter_bias)
        nn.init.ones_(self.adapter_scale)
63
+
64
+
65
class GPT(BaseModel):
    """Adapter V2 model: the adapter `GPT` whose LM head and blocks are built from Adapter V2 layers."""

    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations.
        # Construction order (lm_head, wte, h, ln_f) is kept so that seeded initialization is unchanged.
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        vocab_size, width = config.padded_vocab_size, config.n_embd
        self.lm_head = AdapterV2Linear(width, vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, width),
                "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
                "ln_f": config.norm_class(width, eps=config.norm_eps),
            }
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a model from the Adapter V2 `Config` registered under `name`."""
        config = Config.from_name(name, **kwargs)
        return cls(config)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if not isinstance(module, AdapterV2Linear):
            return
        module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints: route `lm_head.*` to the wrapped `lm_head.linear.*`."""
        mapping = {f"lm_head.{param}": f"lm_head.linear.{param}" for param in ("weight", "bias")}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
98
+
99
+
100
class Block(BaseBlock):
    """A transformer block identical to `lit_gpt.model.Block` except that its attention layer is the
    Adapter V2 `CausalSelfAttention`, which receives the block index."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Deliberately bypass the parent __init__ so the plain attention layer is never allocated
        nn.Module.__init__(self)
        norm_class = config.norm_class
        self.norm_1 = norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        # a shared attention norm means there is no separate second norm module
        if not config.shared_attention_norm:
            self.norm_2 = norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)
        self.config = config
114
+
115
+
116
class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Bypass the parent initializers: they would allocate plain projections that are replaced here
        nn.Module.__init__(self)
        head_dim, n_embd = config.head_size, config.n_embd
        # fused projection: n_head query heads plus one key and one value head per query group
        qkv_width = (config.n_head + 2 * config.n_query_groups) * head_dim
        self.attn = AdapterV2Linear(in_features=n_embd, out_features=qkv_width, bias=config.bias)
        # output projection; `head_size * n_head` may differ from `n_embd` when `head_size` is set explicitly
        self.proj = AdapterV2Linear(head_dim * config.n_head, n_embd, bias=config.bias)
        # regular kv cache, disabled by default
        self.kv_cache: Optional[KVCache] = None

        is_adapter_layer = block_idx >= config.adapter_start_layer
        if is_adapter_layer:
            # learnable adapter prompt
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, n_embd)
            # zero-initialized gate, shape (1, 1, n_head, 1)
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints and with older adapter checkpoints."""
        # base checkpoints store `attn.*`/`proj.*` directly; here they live under the wrapped `.linear`
        mapping = {
            f"{layer}.{param}": f"{layer}.linear.{param}" for layer in ("attn", "proj") for param in ("weight", "bias")
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        # older checkpoints have the head axis of the gate at dim 1; move it to dim 2
        gate_key = prefix + "gating_factor"
        if gate_key in state_dict and state_dict[gate_key].size(1) == self.config.n_head:
            state_dict[gate_key] = state_dict[gate_key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
155
+
156
+
157
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    """`lit_gpt.model.GptNeoxMLP` whose `fc` and `proj` layers are `AdapterV2Linear`s."""

    def __init__(self, config: Config) -> None:
        # Skip the parent __init__ so that no plain Linear layers are allocated
        nn.Module.__init__(self)
        hidden, inner = config.n_embd, config.intermediate_size
        self.fc = AdapterV2Linear(hidden, inner, bias=config.bias)
        self.proj = AdapterV2Linear(inner, hidden, bias=config.bias)

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints: route `fc.*`/`proj.*` to the wrapped `.linear`."""
        mapping = {
            f"{layer}.{param}": f"{layer}.linear.{param}" for layer in ("fc", "proj") for param in ("weight", "bias")
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
175
+
176
+
177
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    """`lit_gpt.model.LLaMAMLP` whose `fc_1`, `fc_2` and `proj` layers are `AdapterV2Linear`s."""

    def __init__(self, config: Config) -> None:
        # Skip the parent __init__ so that no plain Linear layers are allocated
        nn.Module.__init__(self)
        self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        # Keep the config like the sibling adapter MLPs (GptNeoxMLP, LLaMAMoE) do, so inherited code
        # that reads `self.config` also works on this class and its subclasses (e.g. GemmaMLP).
        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints: route `fc_1.*`/`fc_2.*`/`proj.*` to the wrapped `.linear`."""
        mapping = {
            f"{layer}.{param}": f"{layer}.linear.{param}"
            for layer in ("fc_1", "fc_2", "proj")
            for param in ("weight", "bias")
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
196
+
197
+
198
class GemmaMLP(LLaMAMLP):
    """Adapter V2 Gemma MLP: the `LLaMAMLP` projections with a GELU-gated product."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_in, up = self.fc_1(x), self.fc_2(x)
        gated = torch.nn.functional.gelu(gate_in) * up
        return self.proj(gated)
204
+
205
+
206
class LLaMAMoE(lit_gpt.model.LLaMAMoE):
    """`lit_gpt.model.LLaMAMoE` with an `AdapterV2Linear` router and Adapter V2 `LLaMAMLP` experts."""

    def __init__(self, config: Config) -> None:
        # Skip the parent __init__ so that no plain Linear layers are allocated
        nn.Module.__init__(self)
        self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False)
        self.experts = nn.ModuleList([LLaMAMLP(config) for _ in range(config.n_expert)])

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints: route `gate.weight` to the wrapped `gate.linear`."""
        state_dict = map_old_state_dict_weights(state_dict, {"gate.weight": "gate.linear.weight"}, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
219
+
220
+
221
def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
    """Freeze every parameter of `model` except those selected by `adapter_filter`.

    The trainable set is the adapter prompts and gates, the AdapterV2Linear scales and biases, and the
    norm layers. Every other parameter gets `requires_grad=False`.
    """
    for param_name, parameter in model.named_parameters():
        trainable = adapter_filter(param_name, parameter)
        parameter.requires_grad = trainable
lit_gpt/args.py ADDED
@@ -0,0 +1,81 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+
6
@dataclass
class TrainArgs:
    """Training related arguments"""

    save_interval: int = 1000
    """Number of optimizer steps between checkpoints"""
    log_interval: int = 1
    """Number of iterations between logging calls"""
    global_batch_size: int = 64
    """Number of samples between optimizer steps across data-parallel ranks"""
    micro_batch_size: int = 4
    """Number of samples per data-parallel rank"""
    lr_warmup_steps: int = 100
    """Number of iterations with learning rate warmup active"""
    epochs: Optional[int] = None
    """Number of epochs to run"""
    epoch_size: Optional[int] = None
    """Size of the epoch"""
    # TODO: pretrain/tinyllama is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?
    max_tokens: Optional[int] = None
    """Total number of tokens to train on"""
    max_seq_length: Optional[int] = None
    """Limits the length of samples. Off by default"""

    # Optimization args
    learning_rate: float = 1e-3
    weight_decay: float = 0.02
    beta1: float = 0.9
    beta2: float = 0.95
    max_norm: Optional[float] = None
    min_lr: float = 6e-5

    def max_iters(self, devices: int) -> int:
        """Number of iterations

        Computed as ``epochs * epoch_size // devices // micro_batch_size``, i.e. the number of
        micro-batches each data-parallel rank processes (assuming ``epoch_size`` counts samples).

        Raises:
            ValueError: if ``epochs`` or ``epoch_size`` is not set.
            AssertionError: if the result is not positive.
        """
        if self.epochs is None or self.epoch_size is None:
            # without this check the multiplication below fails with an opaque TypeError on `None`
            raise ValueError(
                "`max_iters` requires both `epochs` and `epoch_size` to be set,"
                f" got epochs={self.epochs!r}, epoch_size={self.epoch_size!r}"
            )
        max_iters = self.epochs * self.epoch_size // devices // self.micro_batch_size
        assert max_iters > 0, (
            f"max_iters must be positive, got {max_iters} (epochs={self.epochs}, epoch_size={self.epoch_size},"
            f" devices={devices}, micro_batch_size={self.micro_batch_size})"
        )
        return max_iters

    def gradient_accumulation_iters(self, devices: int) -> int:
        """Number of iterations between gradient synchronizations

        Equals ``batch_size(devices) // micro_batch_size`` (remainders are truncated).
        """
        gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size
        assert gradient_accumulation_iters > 0, (
            f"gradient_accumulation_iters must be positive, got {gradient_accumulation_iters}"
            f" (per-rank batch size {self.batch_size(devices)}, micro_batch_size={self.micro_batch_size})"
        )
        return gradient_accumulation_iters

    def batch_size(self, devices: int) -> int:
        """Number of samples between optimizer steps per data-parallel rank

        Equals ``global_batch_size // devices``.
        """
        batch_size = self.global_batch_size // devices
        assert batch_size > 0, (
            f"per-rank batch size must be positive, got {batch_size}"
            f" (global_batch_size={self.global_batch_size}, devices={devices})"
        )
        return batch_size
55
+
56
+
57
@dataclass
class EvalArgs:
    """Evaluation related arguments"""

    # Evaluation cadence, counted in optimizer steps rather than micro-batch iterations
    interval: int = 600
    """Number of optimizer steps between evaluation calls"""
    # NOTE(review): `None` presumably lets the calling script pick a generation length - confirm at call sites
    max_new_tokens: Optional[int] = None
    """Number of tokens to generate"""
    # NOTE(review): presumably the number of validation batches per evaluation - confirm at call sites
    max_iters: int = 100
    """Number of iterations"""
67
+
68
+
69
@dataclass
class IOArgs:
    """Inputs and outputs related arguments"""

    # NOTE(review): the defaults point at the Alpaca adapter-finetuning layout (data/alpaca -> out/adapter/alpaca);
    # other scripts are presumably expected to override them - confirm at call sites
    # Optional because pretrain/tinyllama hardcodes the path
    train_data_dir: Optional[Path] = Path("data/alpaca")
    """Where to read training data from"""
    # `None` means no separate validation directory was given
    val_data_dir: Optional[Path] = None
    """Where to read validation data from"""
    # `None` means no checkpoint directory was given
    checkpoint_dir: Optional[Path] = None
    """Where to read weights and tokenizer data from"""
    out_dir: Path = Path("out/adapter/alpaca")
    """Where to save artifacts"""