erictransformer-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0

erictransformer/loops/eval_loop.py
@@ -0,0 +1,111 @@
import math
import sys
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch.amp import autocast
from torch.nn import Module
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from erictransformer.eval_models import EvalModel
from erictransformer.exceptions import EricEvalError
from erictransformer.utils import get_precision


@dataclass(kw_only=True)
class EvalResult:
    loss: float
    metrics: Optional[dict] = None


def eval_loop(
    model: Module,
    dataloader: DataLoader,
    eval_tok_data_cases: int,
    eval_bs: int,
    eval_models: List[EvalModel],
    precision: str = "auto",
) -> EvalResult:
    total_loss = 0.0
    total_examples = 0
    pbar = tqdm(
        total=math.ceil(eval_tok_data_cases / eval_bs),
        desc="Model Evaluating",
        position=1,
        leave=False,
        disable=not sys.stdout.isatty(),
    )

    for eval_model in eval_models:
        eval_model.reset()

    model.eval()

    device = next(model.parameters()).device
    precision_type = get_precision(device=device)

    with torch.no_grad():
        for batch in dataloader:
            try:
                try:
                    input_ids = batch["input_ids"].to(device, non_blocking=True)
                    attention_mask = batch["attention_mask"].to(
                        device, non_blocking=True
                    )
                    labels = batch["labels"].to(device, non_blocking=True)
                except Exception as e:
                    raise EricEvalError(f"Failed to extract inputs from batch: {e}")

                if device.type == "cuda":
                    dtype = (
                        torch.bfloat16 if precision_type == "bf16" else
                        torch.float16 if precision_type == "fp16" else
                        torch.float32
                    )
                    with autocast(device_type="cuda", enabled=(dtype != torch.float32), dtype=dtype):
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                else:
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

                bs = batch["input_ids"].size(0)
                result = outputs.loss.item() * bs
            except Exception as e:
                raise EricEvalError(f"Error processing batch: {e}")

            if not math.isnan(result):
                total_loss += result
                total_examples += bs
            if eval_models:
                batch["input_ids"] = batch["input_ids"].to(device)
                batch["attention_mask"] = batch["attention_mask"].to(device)
                batch["labels"] = batch["labels"].to(device)

                for eval_model in eval_models:
                    eval_model(batch, outputs)

            pbar.update(1)

    pbar.close()

    custom_metrics = {}
    for eval_model in eval_models:
        name = eval_model.name
        if name in custom_metrics:
            for i in range(0, 1000):
                new_name = f"{eval_model.name}-{i}"
                if new_name not in custom_metrics:
                    name = new_name
                    break
            else:
                raise EricEvalError(
                    f"Couldn't find unique name for eval model named {eval_model.name}"
                )

        custom_metrics[name] = eval_model.result()
        eval_model.reset()

    average_loss = total_loss / total_examples if total_examples else 0.0

    return EvalResult(loss=average_loss, metrics=custom_metrics)

erictransformer/loops/train_loop.py
@@ -0,0 +1,310 @@
import json
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import PretrainedConfig, PreTrainedTokenizerBase

from erictransformer.args.eric_args import EricTrainArgs
from erictransformer.eval_models import EvalModel
from erictransformer.exceptions import EricTrainError
from erictransformer.eric_tracker.eric_tracker import EricTracker
from erictransformer.loops.eval_loop import eval_loop
from erictransformer.utils import EricTimer
from erictransformer.utils.test import DebugHook


@dataclass
class TrainResult:
    final_train_loss: float
    final_eval_loss: Optional[float] = None
    best_eval_loss: Optional[float] = None


def train_loop(
    args: EricTrainArgs,
    train_dataloader: DataLoader,
    eric_tracker: EricTracker,
    train_steps: int,
    model: Module,
    optim: Optimizer,
    lr_sched: LRScheduler,
    eval_cases: int,
    eval_dataloader: DataLoader,
    tokenizer: PreTrainedTokenizerBase,
    config: PretrainedConfig,
    skip_cases: int,
    starting_epoch: int,
    eval_models: List[EvalModel],
    eric_timer: EricTimer,
    precision_type: str,
) -> TrainResult:
    best_tokenizer_config_saved = False
    checkpoint_tokenizer_config_saved = False
    first_epoch_train_iter = iter(train_dataloader)

    # When resuming mid-epoch, fast-forward the first epoch's iterator past the
    # cases that were already trained on.
    if skip_cases:
        try:
            skip_batches = max(1, int(skip_cases / args.bs))
            for _ in tqdm(range(skip_batches)):
                next(first_epoch_train_iter, None)
        except Exception as e:
            raise EricTrainError(f"Failed to skip initial training cases: {e}")

    device = next(model.parameters()).device

    # Gradient accumulation steps.
    gas = max(1, args.gas)

    start_step = eric_tracker.state.current_step

    for epoch in range(starting_epoch - 1, int(args.epochs)):
        epoch_iter = (
            first_epoch_train_iter
            if epoch == starting_epoch - 1
            else iter(train_dataloader)
        )

        batch_idx = 0
        while True:
            if eric_tracker.state.current_step > train_steps:
                break

            with eric_timer.section("training_core", "iter"):
                try:
                    batch = next(epoch_iter)
                except StopIteration:
                    break

            try:
                with eric_timer.section("training_core", "to_device"):
                    input_ids = batch["input_ids"].to(device, non_blocking=True)
                    attention_mask = batch["attention_mask"].to(
                        device, non_blocking=True
                    )
                    labels = batch["labels"].to(device, non_blocking=True)
            except Exception as e:
                raise EricTrainError(f"Failed to extract inputs from batch: {e}")

            try:
                with eric_timer.section("training_core", "forward"):
                    if device.type == "cuda":
                        with torch.autocast(
                            device_type="cuda",
                            enabled=precision_type != "fp32",
                            dtype=(
                                torch.float16
                                if precision_type == "fp16"
                                else torch.bfloat16
                            ),
                        ):
                            outputs = model(
                                input_ids=input_ids,
                                attention_mask=attention_mask,
                                labels=labels,
                            )
                    else:
                        outputs = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                        )

                with eric_timer.section("training_core", "get loss"):
                    raw_loss = outputs.loss.detach()
                    loss = outputs.loss / gas

            except Exception as e:
                raise EricTrainError(f"Model forward pass failed: {e}")

            try:
                with eric_timer.section("training_core", "backwards"):
                    loss.backward()
            except Exception as e:
                raise EricTrainError(f"Backward pass failed: {e}")

            try:
                do_step = ((batch_idx + 1) % gas == 0) or (
                    (batch_idx + 1) == len(train_dataloader)
                )

                if do_step:
                    if precision_type == "fp16":
                        # for now we don't clip gradients when using bf16 or fp32 since this
                        # operation takes a noticeable amount of time and they're stable enough.
                        with eric_timer.section("training_core", "clip gradients"):
                            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    with eric_timer.section("training_core", "optim.step()"):
                        optim.step()
                    with eric_timer.section("training_core", "lr_sched.step()"):
                        lr_sched.step()
                    with eric_timer.section("training_core", "optim.zero_grad()"):
                        optim.zero_grad()
                    eric_tracker.optim_step()

            except Exception as e:
                raise EricTrainError(f"optim or LR scheduler step failed: {e}")

            try:
                do_eval = bool(eric_tracker.time_to_eval() and eval_cases)

                if do_eval:
                    with eric_timer.section("training_extra", "eval loop"):
                        model.eval()
                        eval_result = eval_loop(
                            model,
                            eval_dataloader,
                            eval_cases,
                            args.eval_bs if args.eval_bs else args.bs,
                            eval_models=eval_models,
                        )
                        model.train()

                    global_eval_loss = float(eval_result.loss)

                    is_best_model = eric_tracker.set_eval_loss(global_eval_loss)
                    eric_tracker.set_metrics(eval_result.metrics)

                    if args.save_best and is_best_model:
                        with eric_timer.section("training_extra", "save best model"):
                            try:
                                path = eric_tracker.tracker_paths.best_model_path
                                model.save_pretrained(path)
                            except Exception as e:
                                path = eric_tracker.tracker_paths.best_model_path
                                raise EricTrainError(
                                    f"Failed to save best model to: {path} | {e}"
                                )

                            try:
                                state_path = (
                                    eric_tracker.tracker_paths.best_model_state_path
                                )
                                with open(state_path, "w") as f:
                                    json.dump(
                                        eric_tracker.state.to_dict(), f, indent=2
                                    )
                            except Exception as e:
                                state_path = (
                                    eric_tracker.tracker_paths.best_model_state_path
                                )
                                raise EricTrainError(
                                    f"Failed to save best model state to: {state_path} | {e}"
                                )

                            if not best_tokenizer_config_saved:
                                try:
                                    tokenizer.save_pretrained(
                                        eric_tracker.tracker_paths.best_model_path
                                    )
                                except Exception as e:
                                    path = eric_tracker.tracker_paths.best_model_path
                                    raise EricTrainError(
                                        f"Failed to save tokenizer to: {path} | {e}"
                                    )
                                try:
                                    config.save_pretrained(
                                        eric_tracker.tracker_paths.best_model_path
                                    )
                                except Exception as e:
                                    path = eric_tracker.tracker_paths.best_model_path
                                    raise EricTrainError(
                                        f"Failed to save config to: {path} | {e}"
                                    )
                                best_tokenizer_config_saved = True
            except Exception as e:
                raise EricTrainError(f"Evaluation loop failed: {e}")

            try:
                do_ckpt = eric_tracker.time_to_checkpoint()

                if do_ckpt:
                    with eric_timer.section("training_extra", "checkpoint"):
                        checkpoint_path = eric_tracker.tracker_paths.checkpoint_path
                        try:
                            model.save_pretrained(checkpoint_path)
                        except Exception as e:
                            raise EricTrainError(
                                f"Failed to save checkpoint model to: {checkpoint_path} | {e}"
                            )

                        if not checkpoint_tokenizer_config_saved:
                            try:
                                tokenizer.save_pretrained(checkpoint_path)
                            except Exception as e:
                                raise EricTrainError(
                                    f"Failed to save tokenizer to: {checkpoint_path} | {e}"
                                )
                            try:
                                config.save_pretrained(checkpoint_path)
                            except Exception as e:
                                raise EricTrainError(
                                    f"Failed to save config to: {checkpoint_path} | {e}"
                                )
                            checkpoint_tokenizer_config_saved = True

                        try:
                            # Reloading the optimizer was causing problems. So for now we restart it.
                            # torch.save(
                            #     optim.state_dict(),
                            #     eric_tracker.tracker_paths.optim_path,
                            # )

                            torch.save(
                                lr_sched.state_dict(),
                                eric_tracker.tracker_paths.lr_sched_path,
                            )

                        except Exception as e:
                            raise EricTrainError(
                                f"Failed to save lr scheduler state to: "
                                f"{eric_tracker.tracker_paths.lr_sched_path} | {e}"
                            )
            except Exception as e:
                raise EricTrainError(str(e))

            try:
                with eric_timer.section("training_extra", "logging"):
                    if (
                        eric_tracker.time_to_log()
                        or eric_tracker.time_to_eval()
                        or eric_tracker.time_to_checkpoint()
                    ):
                        eric_timer.report()

                    num_tokens = int(attention_mask.sum().item())
                    eric_tracker.step(
                        raw_loss, lr_sched.get_last_lr()[0], num_tokens=num_tokens
                    )
            except Exception as e:
                raise EricTrainError(f"Tracker step failed: {e}")

            debug_hook_post_checkpoint()

            batch_idx += 1

        eric_tracker.mark_epoch()

        if eric_tracker.state.current_step >= train_steps:
            break

    debug_hook_steps({
        "start_step": start_step,
        "total_steps": eric_tracker.state.current_step,
    })

    return TrainResult(
        final_train_loss=eric_tracker.state.train_loss,
        final_eval_loss=eric_tracker.state.eval_loss if eval_cases else None,
        best_eval_loss=eric_tracker.state.best_eval_loss if eval_cases else None,
    )


# Module-level debug hooks referenced by the loop above.
debug_hook_post_checkpoint = DebugHook()
debug_hook_steps = DebugHook()

erictransformer/utils/__init__.py
@@ -0,0 +1,21 @@
from erictransformer.utils.init import get_model_components, et_retrieve_tokenizer
from erictransformer.utils.init.get_device import et_get_device
from erictransformer.utils.init.get_logger import et_get_logger
from erictransformer.utils.timer import EricTimer
from erictransformer.utils.tok_data import (
    get_procs,
    prepare_output_locations,
    resolve_input_files,
    tok_dir_to_dataset,
    write_details_file,
)
from erictransformer.utils.tok_data.save_tok_data import save_json_tok_data
from erictransformer.utils.train import (
    create_tracker_dir,
    get_num_training_steps,
    get_optim,
    get_precision,
    get_tok_data,
    make_dir,
    resume_training,
)

erictransformer/utils/init/get_components.py
@@ -0,0 +1,204 @@
from typing import List, Optional, Union

import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from erictransformer.exceptions import EricLoadModelError, EricLoadTokenizerError


def _get_torch_dtype(torch_dtype: str) -> Union[torch.dtype, str]:
    if torch_dtype == "auto":
        # pass "auto" through unchanged so from_pretrained can infer the dtype
        return "auto"
    elif torch_dtype == "fp32":
        return torch.float32
    elif torch_dtype == "fp16":
        return torch.float16
    elif torch_dtype == "bf16":
        return torch.bfloat16
    else:
        raise ValueError(
            f"Invalid torch_dtype {torch_dtype}. Provide one of auto, fp32, fp16, or bf16."
        )


def get_model(
    model_name_path: Union[str, PreTrainedModel],
    model_class: AutoModel,
    trust_remote_code: bool,
    precision,
) -> PreTrainedModel:
    try:
        if isinstance(model_name_path, PreTrainedModel):
            return model_name_path

        loaded_config = AutoConfig.from_pretrained(
            model_name_path, trust_remote_code=trust_remote_code
        )
        torch_dtype = _get_torch_dtype(precision)

        model = model_class.from_pretrained(
            model_name_path,
            config=loaded_config,
            trust_remote_code=trust_remote_code,
            dtype=torch_dtype,
        )

        return model

    except Exception as e:
        raise EricLoadModelError(f"Failed to load model from '{model_name_path}': {e}")


def et_retrieve_tokenizer(
    tokenizer_path: str, trust_remote_code: bool
):
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=trust_remote_code,
        )

        return tokenizer
    except Exception as e:
        raise EricLoadTokenizerError(
            f"Failed to load tokenizer from '{tokenizer_path}': {e}"
        )


def get_tokenizer(
    model_name_path: str,
    tokenizer_path: Union[str, PreTrainedTokenizerBase],
    trust_remote_code: bool,
) -> PreTrainedTokenizerBase:
    if (
        tokenizer_path is None
    ):  # by default we use the same value provided to model_name_path
        tokenizer = et_retrieve_tokenizer(
            tokenizer_path=model_name_path,
            trust_remote_code=trust_remote_code,
        )

    elif isinstance(tokenizer_path, str):
        tokenizer = et_retrieve_tokenizer(
            tokenizer_path=tokenizer_path,
            trust_remote_code=trust_remote_code,
        )

    elif isinstance(tokenizer_path, PreTrainedTokenizerBase):
        # tokenizer provided directly
        tokenizer = tokenizer_path
        # maybe remove the following code since if an advanced user provides their own tokenizer they might want full control
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
    else:
        raise ValueError(
            "Invalid tokenizer_path. Provide None, a string, or an AutoTokenizer."
        )

    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "<pad>"})

    return tokenizer


def get_model_components(
    model_name_path: Union[str, PreTrainedModel, None],
    trust_remote_code: bool,
    model_class: AutoModel,
    tokenizer_path: Union[str, PreTrainedTokenizerBase],
    precision: str,
):
    tokenizer = get_tokenizer(
        model_name_path=model_name_path,
        tokenizer_path=tokenizer_path,
        trust_remote_code=trust_remote_code,
    )

    if model_name_path is not None:
        model = get_model(
            model_name_path=model_name_path,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            precision=precision,
        )
        config = model.config
    else:
        model = None
        config = None

    return config, tokenizer, model


def get_model_components_tc(
    model_name_path: Union[str, PreTrainedModel, None],
    trust_remote_code: bool,
    model_class: AutoModel,
    tokenizer_path: Union[str, PreTrainedTokenizerBase],
    precision: str,
    labels: Optional[List[str]] = None,
):
    config = AutoConfig.from_pretrained(
        model_name_path, trust_remote_code=trust_remote_code
    )
    reset_labels = False  # set to True if we should reset the labels to the size of num_labels

    config_id2label = getattr(config, "id2label", None)
    config_label2id = getattr(config, "label2id", None)

    if not config_id2label or not config_label2id:
        reset_labels = True
    elif labels is not None:
        reset_labels = True
    else:
        id2label_length = len(config_id2label)
        label2id_length = len(config_label2id)
        if id2label_length != label2id_length:
            reset_labels = True

    if reset_labels:
        if labels is None:
            labels = ["LABEL_0", "LABEL_1"]

        if labels:
            config.id2label = {i: labels[i] for i in range(len(labels))}
        else:
            num_labels = 2
            config.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}

        config.label2id = {v: k for k, v in config.id2label.items()}

        config.num_labels = len(config.id2label)

    torch_dtype = _get_torch_dtype(precision)

    model = model_class.from_pretrained(
        model_name_path,
        config=config,
        dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        ignore_mismatched_sizes=True,
    )

    tokenizer = get_tokenizer(
        model_name_path=model_name_path,
        tokenizer_path=tokenizer_path,
        trust_remote_code=trust_remote_code,
    )

    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "<pad>"})

    return config, tokenizer, model

erictransformer/utils/init/get_device.py
@@ -0,0 +1,22 @@
import torch
from torch import device

from erictransformer.exceptions import EricDeviceError


def et_get_device() -> device:
    try:
        d = None
        if torch.backends.mps.is_available():
            if torch.backends.mps.is_built():
                d = torch.device("mps")

        if torch.cuda.is_available():
            d = torch.device("cuda:0")

        if not d:
            d = torch.device("cpu")

        return d
    except Exception as e:
        raise EricDeviceError(f"Device selection failed: {e}")

erictransformer/utils/init/get_logger.py
@@ -0,0 +1,15 @@
import logging
from logging import Logger


def et_get_logger() -> Logger:
    logger = logging.getLogger("erictransformer")
    handler = logging.StreamHandler()
    handler.addFilter(logging.Filter("erictransformer"))
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[handler],
    )
    return logger