erictransformer-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/eric_transformer.py
@@ -0,0 +1,534 @@
import math
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple, Union

import torch
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
from torch.optim.lr_scheduler import ConstantLR
from tqdm.auto import tqdm
from transformers import (
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

from erictransformer.args import CallResult, CallArgs, EricEvalArgs, TokArgs, EricTrainArgs
from erictransformer.eval_models import EvalModel
from erictransformer.exceptions import (
    EricDatasetError,
    EricDeviceError,
    EricIOError,
    EricNoModelError,
    EricPushError,
    EricResumeError,
    EricSaveError,
)
from erictransformer.eric_tracker import EricTracker
from erictransformer.loops import EvalResult, TrainResult, eval_loop, train_loop
from erictransformer.utils import (
    EricTimer,
    create_tracker_dir,
    get_procs,
    get_num_training_steps,
    get_optim,
    get_precision,
    get_tok_data,
    et_get_device,
    et_get_logger,
    prepare_output_locations,
    resolve_input_files,
    resume_training,
    save_json_tok_data,
    tok_dir_to_dataset,
    write_details_file,
)
from erictransformer.validator import EvalValidator, TokValidator, TrainValidator


@dataclass
class EricTransformerArgs:
    # Arguments class for EricTransformer.__init__()
    model_name: str
    model_class: AutoModel
    use_auth_token: Union[str, bool, None] = None
    trust_remote_code: bool = False
    tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast, None] = None


class EricTransformer(ABC):
    def __init__(self, eric_args: EricTransformerArgs):
        self.logger = et_get_logger()
        self.eric_args = eric_args
        self.device = et_get_device()
        self.precision_type = get_precision(device=self.device)

        self.config, self.tokenizer, self.model = self._load_model_components()

        if self.model is not None:
            self.model.resize_token_embeddings(len(self.tokenizer))
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.logger.info("Using device: %s", self.device)

        # These are set in child classes
        self._data_collator = None

        self._train_just_happened = True

        self.eval_models = self._get_default_eval_models()

    @abstractmethod
    def __call__(
        self, text: str, args: CallArgs = CallArgs()
    ) -> List[CallResult]:
        raise NotImplementedError()

    @abstractmethod
    def _tok_function(
        self, raw_dataset, args: TokArgs, file_type: str, procs: int = 1
    ) -> Dataset:
        raise NotImplementedError()

    @abstractmethod
    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        pass

    @abstractmethod
    def _format_tokenized_example(self, example: dict) -> dict:
        pass

    @abstractmethod
    def _get_default_eval_models(self) -> List[EvalModel]:
        pass

    @abstractmethod
    def _get_readme(self, repo_id: str) -> str:
        pass

    @abstractmethod
    def _prep_model(self):
        pass

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        *,
        resume_path: str = "",  # a path to a dir
    ) -> TrainResult:
        out_dir = create_tracker_dir(args.out_dir, "train", args.run_name)

        timer_dir = os.path.join(out_dir, "time")

        eric_timer = EricTimer(out_dir=timer_dir)

        eric_train_validator = TrainValidator(
            self.logger,
            train_path=train_path,
            args=args,
            eval_path=eval_path,
            resume_path=resume_path,
            model=self.model,
        )
        # the validator may alter some of the parameters
        args = eric_train_validator.args

        tracker_state = None
        if resume_path:
            with eric_timer.section("resume", "load resume path"):
                (
                    tracker_state,
                    args_dict,
                    model_tokenizer_path,
                    lr_sched_path,
                ) = resume_training(resume_path)
                args = EricTrainArgs(**args_dict)

            self.config, self.tokenizer, self.model = self._load_model_components()

        if eric_train_validator.train_source == "file":
            with eric_timer.section("tokenize", "tokenize_train"):
                current_train_tok_dir = self._tokenize_input_file(
                    train_path, out_dir, "train"
                )
                self.logger.info(
                    f"Train data has been tokenized and saved to {current_train_tok_dir}"
                )

        else:  # has to be "train_tok_dir"
            current_train_tok_dir = train_path

        with eric_timer.section("tokenize", "load_train_data"):
            train_dataloader, num_train_cases = get_tok_data(
                current_train_tok_dir,
                args.seed,
                args.bs,
                self._data_collator,
                self.device,
            )

        skip_eval = False
        if eric_train_validator.eval_source == "file":
            with eric_timer.section("tokenize", "tokenize_eval"):
                current_eval_tok_dir = self._tokenize_input_file(
                    eval_path, out_dir, "eval"
                )
                self.logger.info(
                    f"Eval data has been tokenized and saved to {current_eval_tok_dir}"
                )
        elif eric_train_validator.eval_source == "folder":
            current_eval_tok_dir = eval_path
        else:
            self.logger.info(
                "No evaluation data will be used. Provide eval_tok_dir or eval_filepath"
            )
            skip_eval = True

        if not skip_eval:
            with eric_timer.section("tokenize", "load_eval_data"):
                eval_dataloader, num_eval_cases = get_tok_data(
                    current_eval_tok_dir,
                    args.seed,
                    args.eval_bs if args.eval_bs else args.bs,
                    self._data_collator,
                    self.device,
                )
        else:
            num_eval_cases = 0
            eval_dataloader = None

        try:
            self.model.to(self.device)
        except Exception as e:
            raise EricDeviceError(f"Failed to move model to device: {e}")

        train_steps = get_num_training_steps(
            train_cases=num_train_cases,
            num_devices=1,
            epochs=args.epochs,
            gas=args.gas,
            bs=args.bs,
        )

        eric_tracker = EricTracker(
            args, train_steps, out_dir, tracker_state if resume_path else None
        )

        optim = get_optim(args, self.model, self.logger)

        lr_sched = None
        if args.lr_sched == "warmup_then_decay":
            if train_steps < 8:
                self.logger.info(
                    "You need at least 8 steps to use the 'warmup_then_decay' lr_sched. Falling back to 'constant'"
                )
            else:
                num_warmup_steps = math.ceil(train_steps / 10)  # 10% warmup

                warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                    optim,
                    start_factor=0.1,  # start the warmup at 10% of the base lr
                    end_factor=1.0,
                    total_iters=num_warmup_steps,
                )

                cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optim,
                    T_max=train_steps - num_warmup_steps,
                    eta_min=0.1 * args.lr,
                )

                lr_sched = torch.optim.lr_scheduler.SequentialLR(
                    optim,
                    schedulers=[warmup_scheduler, cosine_scheduler],
                    milestones=[num_warmup_steps],
                )

        if lr_sched is None:
            lr_sched = ConstantLR(optim, factor=1, total_iters=1)

        skip_cases = 0
        starting_epoch = 1

        if resume_path:
            cases_per_step = args.bs * args.gas
            if tracker_state["last_checkpoint_step"] == train_steps:
                raise EricResumeError(
                    "You provided a path to an already completed training run"
                )
            skip_cases = tracker_state["last_checkpoint_step"] * cases_per_step
            current_epoch = tracker_state["epoch"]
            starting_epoch = current_epoch
            if current_epoch > 1:
                skip_cases = skip_cases - (current_epoch - 1) * num_train_cases
            try:
                with eric_timer.section("resume", "load lr scheduler"):
                    lr_sched.load_state_dict(torch.load(lr_sched_path))
                # for now we don't resume the optimizer's state as it was causing problems.
            except Exception as e:
                raise EricResumeError(
                    f"Could not load lr scheduler state from resume path: {e}"
                )

        self.logger.info(f"Model on {next(self.model.parameters()).device}")

        self.model.train()

        train_loop_result: TrainResult = train_loop(
            args=args,
            train_dataloader=train_dataloader,
            eric_tracker=eric_tracker,
            train_steps=train_steps,
            model=self.model,
            optim=optim,
            lr_sched=lr_sched,
            eval_cases=num_eval_cases,
            eval_dataloader=eval_dataloader,
            tokenizer=self.tokenizer,
            config=self.config,
            skip_cases=skip_cases,
            starting_epoch=starting_epoch,
            eval_models=self.eval_models,
            eric_timer=eric_timer,
            precision_type=self.precision_type,
        )

        eric_tracker.close()

        self._train_just_happened = True
        eric_timer.report()
        return train_loop_result

    def eval(self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()) -> EvalResult:
        eval_validator = EvalValidator(
            model=self.model,
            eval_path=eval_path,
            args=args,
            out_dir=args.out_dir,
            logger=self.logger,
        )

        if eval_validator.eval_source == "file":
            out_directory = create_tracker_dir(
                args.out_dir, "eval", args.run_name
            )
            current_eval_tok_dir = self._tokenize_input_file(
                eval_path, out_directory, "eval", in_eval=True
            )
            self.logger.info(
                f"Eval data has been tokenized and saved to {current_eval_tok_dir}"
            )
        else:  # has to be an already tokenized data dir
            current_eval_tok_dir = eval_path

        dataloader, tok_data_len = get_tok_data(
            current_eval_tok_dir,
            args.seed,
            args.bs,
            self._data_collator,
            self.device,
        )

        try:
            self.model.to(self.device)
        except Exception as e:
            raise EricDeviceError(f"Failed to move model to device: {e}")

        self.model.eval()
        return eval_loop(
            self.model,
            dataloader,
            tok_data_len,
            args.bs,
            eval_models=self.eval_models,
        )

    def tok(self, path: str, out_dir: str, args: TokArgs = TokArgs()):
        _ = TokValidator(
            input_data=path, out_dir=out_dir, args=args
        )

        file_infos = resolve_input_files(path)

        try:
            os.makedirs(out_dir, exist_ok=True)
        except Exception as e:
            raise EricIOError(f"Failed to make directory: {out_dir}. {e}")
        output_paths, output_data_loc = prepare_output_locations(
            out_dir, args.shards
        )

        num_cases = 0
        p_num_files = tqdm(
            total=len(file_infos), desc="# tokenizing shard #", position=0
        )

        for file, file_type in file_infos:
            if args.max_cases > 0 and num_cases >= args.max_cases:
                break
            try:
                if file_type == "text":
                    dataset = load_dataset(
                        file_type,
                        data_files=[file],
                        split="train",
                        sample_by="document",
                    )
                else:
                    dataset = load_dataset(file_type, data_files=[file], split="train")

            except Exception as e:
                raise EricDatasetError(f"Error loading {file} of type {file_type}: {e}")

            # Trim dataset if necessary
            remaining = (
                args.max_cases - num_cases
                if args.max_cases > 0
                else len(dataset)
            )
            trimmed_count = min(len(dataset), remaining)
            dataset = dataset.select(range(trimmed_count))
            procs = get_procs(args.procs)
            # Tokenize trimmed data
            tokenized = self._tok_function(
                dataset, args=args, file_type=file_type, procs=procs
            )

            save_json_tok_data(
                tokenized,
                output_data_loc,
                args.shards,
                self._format_tokenized_example,
            )

            num_cases += len(tokenized)

            p_num_files.update(1)

        write_details_file(out_dir, num_cases, output_paths)

        for f in output_data_loc:
            f.close()

    def save(self, path: str):
        if self.model is None:
            raise EricNoModelError(
                f"No model found. Provide a model_name or PreTrainedModel when instantiating {self.__class__.__name__}."
            )
        try:
            self.model.save_pretrained(path)
        except Exception as e:
            raise EricSaveError(f"Error saving model to {path}: {e}")
        try:
            self.tokenizer.save_pretrained(path)
        except Exception as e:
            raise EricSaveError(f"Error saving tokenizer to {path}: {e}")
        try:
            self.config.save_pretrained(path)
        except Exception as e:
            raise EricSaveError(f"Error saving config to {path}: {e}")

        self._prep_model()  # Reset self.model.generation_config with the Eric Transformer values for CHAT, GEN and TT

    def push(self, repo_id: str, private: bool = True):
        if self.model is None:
            raise EricNoModelError(
                f"No model found. Provide a model_name or PreTrainedModel when instantiating {self.__class__.__name__}."
            )

        api = HfApi()
        try:
            api.create_repo(repo_id, exist_ok=True, private=private)
        except Exception as e:
            self.logger.warning(f"Could not create repo: {e}")
            return
        try:
            has_readme = api.file_exists(repo_id, "README.md")
        except Exception as e:
            self.logger.warning(f"Could not check for an existing README: {e}")
            return

        if not has_readme:
            readme_text = self._get_readme(repo_id)
            try:
                self.logger.info("Pushing README...")

                api.upload_file(
                    path_or_fileobj=readme_text.encode("utf-8"),
                    path_in_repo="README.md",
                    repo_id=repo_id,
                    repo_type="model",
                )

            except Exception as e:
                # Don't fail the whole push if README upload fails; just warn.
                self.logger.warning(f"Error pushing README: {e}")

        try:
            self.logger.info("Pushing model...")
            self.model.push_to_hub(
                repo_id,
                private=private,
                commit_message="Uploaded model from Eric Transformer",
            )
        except Exception as e:
            raise EricPushError(f"Error pushing model: {e}")
        try:
            self.logger.info("Pushing tokenizer...")
            self.tokenizer.push_to_hub(
                repo_id,
                private=private,
                commit_message="Uploaded tokenizer from Eric Transformer",
            )
        except Exception as e:
            raise EricPushError(f"Error pushing tokenizer: {e}")
        try:
            self.logger.info("Pushing config...")
            self.config.push_to_hub(
                repo_id,
                private=private,
                commit_message="Uploaded config from Eric Transformer",
            )
        except Exception as e:
            raise EricPushError(f"Error pushing config: {e}")

    def _tokenize_input_file(
        self, train_path: str, out_dir: str, label: str, in_eval: bool = False
    ) -> str:
        dir_name = f"tok_{label}_data" if in_eval else f"data/tok_{label}_data"
        tok_dir = os.path.join(out_dir, dir_name)
        try:
            os.makedirs(tok_dir, exist_ok=True)
        except Exception as e:
            raise EricIOError(f"Error making directory ({tok_dir}): {e}")

        self.tok(train_path, tok_dir)
        self.logger.info(
            f"{label.capitalize()} data has been tokenized and saved to {tok_dir}"
        )
        return tok_dir

    def _get_model_ready_inference(self):
        if self._train_just_happened:
            self.logger.info(f"Moving model to {self.device}")
            try:
                self.model.to(self.device)
            except Exception as e:
                raise EricDeviceError(f"Failed to move model to device: {e}")

            self._train_just_happened = False

    @staticmethod
    def tok_dir_to_dataset(tok_dir: str) -> Dataset:
        return tok_dir_to_dataset(tok_dir)
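Not shipped in the wheel: the `warmup_then_decay` branch of `train()` above composes three stock PyTorch schedulers. A minimal standalone sketch of the same pattern, with a toy parameter, base lr, and step count assumed for illustration:

import math

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Toy stand-ins for the package's model parameters and optimizer.
base_lr = 1e-4
params = [torch.nn.Parameter(torch.zeros(1))]
optim = torch.optim.AdamW(params, lr=base_lr)

train_steps = 100
num_warmup_steps = math.ceil(train_steps / 10)  # 10% warmup, as in train()

# Linear ramp from 10% of base_lr up to base_lr over the warmup steps...
warmup = LinearLR(optim, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps)
# ...then cosine decay down to 10% of base_lr over the remaining steps.
cosine = CosineAnnealingLR(optim, T_max=train_steps - num_warmup_steps, eta_min=0.1 * base_lr)

# SequentialLR hands control from warmup to cosine at the milestone step.
lr_sched = SequentialLR(optim, schedulers=[warmup, cosine], milestones=[num_warmup_steps])

for step in range(train_steps):
    optim.step()
    lr_sched.step()
    if step % 25 == 0:
        print(step, optim.param_groups[0]["lr"])

When there are fewer than 8 steps, or `args.lr_sched` is not "warmup_then_decay", `train()` instead falls back to `ConstantLR(optim, factor=1, total_iters=1)`, i.e. a fixed learning rate.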
erictransformer/eval_models/__init__.py
@@ -0,0 +1 @@
from erictransformer.eval_models.eval_model import EvalModel, TCAccuracyEvalModel
erictransformer/eval_models/eval_model.py
@@ -0,0 +1,75 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional

import torch
from transformers.modeling_outputs import ModelOutput

from erictransformer.exceptions import EricEvalModelError


class EvalModel(ABC):
    def __init__(self, name: str):
        self.name = name

    @abstractmethod
    def reset(self) -> None:
        pass

    @abstractmethod
    def __call__(self, batch: Dict[str, torch.Tensor], outputs: ModelOutput) -> None:
        pass

    @abstractmethod
    def result(self) -> Optional[float]:
        pass

    @abstractmethod
    def _confirm_batch_params(
        self, batch: Dict[str, torch.Tensor], outputs: ModelOutput
    ) -> None:
        pass


class TCAccuracyEvalModel(EvalModel):
    def __init__(self):
        self.correct_preds = 0
        self.total_preds = 0
        self.name = "accuracy"
        super().__init__(self.name)

    def reset(self) -> None:
        self.correct_preds = 0
        self.total_preds = 0

    def __call__(self, batch: Dict[str, torch.Tensor], outputs: ModelOutput) -> None:
        try:
            self._confirm_batch_params(batch, outputs)
            logits = outputs.logits
            labels = batch["labels"]
            preds = torch.argmax(logits, dim=-1)
            self.correct_preds += (preds == labels).sum().item()
            self.total_preds += labels.size(0)
        except Exception as e:
            raise EricEvalModelError(f"error calling {self.__class__.__name__}: {e}")

    def result(self) -> Optional[float]:
        return self.correct_preds / self.total_preds if self.total_preds > 0 else None

    def _confirm_batch_params(
        self, batch: Dict[str, torch.Tensor], outputs: ModelOutput
    ) -> None:
        if "labels" not in batch:
            raise EricEvalModelError("Batch must contain a 'labels' key.")
        if not isinstance(batch["labels"], torch.Tensor):
            raise EricEvalModelError("batch['labels'] must be a torch.Tensor.")
        if not hasattr(outputs, "logits"):
            raise EricEvalModelError("ModelOutput must have a 'logits' attribute.")
        if not isinstance(outputs.logits, torch.Tensor):
            raise EricEvalModelError("outputs.logits must be a torch.Tensor.")

        bs_logits = outputs.logits.shape[0]
        bs_labels = batch["labels"].shape[0]
        if bs_logits != bs_labels:
            raise EricEvalModelError(
                f"Mismatch in batch size: logits has {bs_logits}, labels has {bs_labels}"
            )
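Not part of the diff: a short sketch of how `TCAccuracyEvalModel` accumulates accuracy across eval batches, assuming the package is installed; the random logits and labels are stand-ins for real classifier outputs.

import torch
from transformers.modeling_outputs import SequenceClassifierOutput

from erictransformer.eval_models import TCAccuracyEvalModel

metric = TCAccuracyEvalModel()
metric.reset()

# Two fake batches: 4 cases each, 3 classes. _confirm_batch_params checks that
# 'labels' is a tensor and that logits/labels batch sizes agree before counting.
for _ in range(2):
    logits = torch.randn(4, 3)
    labels = torch.randint(0, 3, (4,))
    metric(batch={"labels": labels}, outputs=SequenceClassifierOutput(logits=logits))

print(metric.name, metric.result())  # accuracy in [0, 1], or None if no batches were seen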
erictransformer/exceptions/__init__.py
@@ -0,0 +1,19 @@
from erictransformer.exceptions.eric_exceptions import (
    EricChatTemplateError,
    EricDatasetError,
    EricDeviceError,
    EricEvalError,
    EricEvalModelError,
    EricInferenceError,
    EricInputError,
    EricIOError,
    EricLoadModelError,
    EricLoadPipelineError,
    EricLoadTokenizerError,
    EricNoModelError,
    EricPushError,
    EricResumeError,
    EricSaveError,
    EricTokenizationError,
    EricTrainError,
)
erictransformer/exceptions/eric_exceptions.py
@@ -0,0 +1,74 @@
class EricTransformerError(Exception):
    pass


class EricInputError(EricTransformerError):
    pass


class EricLoadModelError(EricTransformerError):
    pass


class EricLoadTokenizerError(EricTransformerError):
    pass


class EricLoadPipelineError(EricTransformerError):
    pass


class EricResumeError(EricTransformerError):
    pass


class EricInferenceError(EricTransformerError):
    pass


class EricDeviceError(EricTransformerError):
    pass


class EricIOError(EricTransformerError):
    pass


class EricTokenizationError(EricTransformerError):
    pass


class EricDatasetError(EricTransformerError):
    pass


class EricChatTemplateError(EricTransformerError):
    pass


class EricEvalModelError(EricTransformerError):
    pass


class EricSaveError(EricTransformerError):
    pass


class EricNoModelError(EricTransformerError):
    pass


class EricPushError(EricTransformerError):
    pass


class EricPlotError(EricTransformerError):
    pass


class EricEvalError(EricTransformerError):
    pass


class EricTrainError(EricTransformerError):
    pass