erictransformer-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/eric_transformer.py
@@ -0,0 +1,534 @@
+ import math
+ import os
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import List, Tuple, Union
+
+ import torch
+ from datasets import Dataset, load_dataset
+ from huggingface_hub import HfApi
+ from torch.optim.lr_scheduler import ConstantLR
+ from tqdm.auto import tqdm
+ from transformers import (
+     AutoModel,
+     PretrainedConfig,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+     PreTrainedTokenizerBase,
+     PreTrainedTokenizerFast,
+ )
+
+ from erictransformer.args import CallResult, CallArgs, EricEvalArgs, TokArgs, EricTrainArgs
+ from erictransformer.eval_models import EvalModel
+ from erictransformer.exceptions import (
+     EricDatasetError,
+     EricDeviceError,
+     EricIOError,
+     EricNoModelError,
+     EricPushError,
+     EricResumeError,
+     EricSaveError,
+ )
+ from erictransformer.eric_tracker import EricTracker
+ from erictransformer.loops import EvalResult, TrainResult, eval_loop, train_loop
+ from erictransformer.utils import (
+     EricTimer,
+     create_tracker_dir,
+     get_procs,
+     get_num_training_steps,
+     get_optim,
+     get_precision,
+     get_tok_data,
+     et_get_device,
+     et_get_logger,
+     prepare_output_locations,
+     resolve_input_files,
+     resume_training,
+     save_json_tok_data,
+     tok_dir_to_dataset,
+     write_details_file,
+ )
+ from erictransformer.validator import EvalValidator, TokValidator, TrainValidator
+
+
+ @dataclass
+ class EricTransformerArgs:
+     # Arguments class for EricTransformer.__init__()
+     model_name: str
+     model_class: AutoModel
+     use_auth_token: Union[str, bool, None] = None
+     trust_remote_code: bool = False
+     tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast, None] = None
+
+
+ class EricTransformer(ABC):
+     def __init__(self, eric_args: EricTransformerArgs):
+         self.logger = et_get_logger()
+         self.eric_args = eric_args
+         self.device = et_get_device()
+         self.precision_type = get_precision(device=self.device)
+
+         self.config, self.tokenizer, self.model = self._load_model_components()
+
+         if self.model is not None:
+             self.model.resize_token_embeddings(len(self.tokenizer))
+             self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+         self.logger.info("Using device: %s", self.device)
+
+         # These are set in child classes
+         self._data_collator = None
+
+         self._train_just_happened = True
+
+         self.eval_models = self._get_default_eval_models()
+
+     @abstractmethod
+     def __call__(self, text: str, args: CallArgs = CallArgs()) -> List[CallResult]:
+         raise NotImplementedError()
+
+     @abstractmethod
+     def _tok_function(
+         self, raw_dataset, args: TokArgs, file_type: str, procs: int = 1
+     ) -> Dataset:
+         raise NotImplementedError()
+
+     @abstractmethod
+     def _load_model_components(
+         self,
+     ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
+         pass
+
+     @abstractmethod
+     def _format_tokenized_example(self, example: dict) -> dict:
+         pass
+
+     @abstractmethod
+     def _get_default_eval_models(self) -> List[EvalModel]:
+         pass
+
+     @abstractmethod
+     def _get_readme(self, repo_id: str) -> str:
+         pass
+
+     @abstractmethod
+     def _prep_model(self):
+         pass
+
+     def train(
+         self,
+         train_path: str = "",
+         args: EricTrainArgs = EricTrainArgs(),
+         eval_path: str = "",
+         *,
+         resume_path: str = "",  # a path to a dir
+     ) -> TrainResult:
+         out_dir = create_tracker_dir(args.out_dir, "train", args.run_name)
+
+         timer_dir = os.path.join(out_dir, "time")
+
+         eric_timer = EricTimer(out_dir=timer_dir)
+
+         eric_train_validator = TrainValidator(
+             self.logger,
+             train_path=train_path,
+             args=args,
+             eval_path=eval_path,
+             resume_path=resume_path,
+             model=self.model,
+         )
+         # the validator may alter some of the parameters
+         args = eric_train_validator.args
+
+         tracker_state = None
+         if resume_path:
+             with eric_timer.section("resume", "load resume path"):
+                 (
+                     tracker_state,
+                     args_dict,
+                     model_tokenizer_path,
+                     lr_sched_path,
+                 ) = resume_training(resume_path)
+                 args = EricTrainArgs(**args_dict)
+
+                 self.config, self.tokenizer, self.model = self._load_model_components()
+
+         if eric_train_validator.train_source == "file":
+             with eric_timer.section("tokenize", "tokenize_train"):
+                 current_train_tok_dir = self._tokenize_input_file(
+                     train_path, out_dir, "train"
+                 )
+                 self.logger.info(
+                     f"Train data has been tokenized and saved to {current_train_tok_dir}"
+                 )
+
+         else:  # has to be "train_tok_dir"
+             current_train_tok_dir = train_path
+
+         with eric_timer.section("tokenize", "load_train_data"):
+             train_dataloader, num_train_cases = get_tok_data(
+                 current_train_tok_dir,
+                 args.seed,
+                 args.bs,
+                 self._data_collator,
+                 self.device,
+             )
+
+         skip_eval = False
+         if eric_train_validator.eval_source == "file":
+             with eric_timer.section("tokenize", "tokenize_eval"):
+                 current_eval_tok_dir = self._tokenize_input_file(
+                     eval_path, out_dir, "eval"
+                 )
+                 self.logger.info(
+                     f"Eval data has been tokenized and saved to {current_eval_tok_dir}"
+                 )
+         elif eric_train_validator.eval_source == "folder":
+             current_eval_tok_dir = eval_path
+         else:
+             self.logger.info(
+                 "No evaluation data will be used. Provide eval_tok_dir or eval_filepath"
+             )
+             skip_eval = True
+
+         if not skip_eval:
+             with eric_timer.section("tokenize", "load_eval_data"):
+                 eval_dataloader, num_eval_cases = get_tok_data(
+                     current_eval_tok_dir,
+                     args.seed,
+                     args.eval_bs if args.eval_bs else args.bs,
+                     self._data_collator,
+                     self.device,
+                 )
+         else:
+             num_eval_cases = 0
+             eval_dataloader = None
+
+         try:
+             self.model.to(self.device)
+         except Exception as e:
+             raise EricDeviceError(f"Failed to move model to device: {e}")
+
+         train_steps = get_num_training_steps(
+             train_cases=num_train_cases,
+             num_devices=1,
+             epochs=args.epochs,
+             gas=args.gas,
+             bs=args.bs,
+         )
+
+         eric_tracker = EricTracker(
+             args, train_steps, out_dir, tracker_state if resume_path else None
+         )
+
+         optim = get_optim(args, self.model, self.logger)
+
+         lr_sched = None
+         if args.lr_sched == "warmup_then_decay":
+             if train_steps < 8:
+                 self.logger.info(
+                     "You need at least 8 steps to use the 'warmup_then_decay' lr_sched. Falling back to 'constant'"
+                 )
+             else:
+                 num_warmup_steps = math.ceil(train_steps / 10)  # 10% warmup
+
+                 warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
+                     optim,
+                     start_factor=0.1,  # warmup starts at 10% of the peak lr
+                     end_factor=1.0,
+                     total_iters=num_warmup_steps,
+                 )
+
+                 cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                     optim,
+                     T_max=train_steps - num_warmup_steps,
+                     eta_min=0.1 * args.lr,
+                 )
+
+                 lr_sched = torch.optim.lr_scheduler.SequentialLR(
+                     optim,
+                     schedulers=[warmup_scheduler, cosine_scheduler],
+                     milestones=[num_warmup_steps],
+                 )
+
+         if lr_sched is None:
+             lr_sched = ConstantLR(optim, factor=1, total_iters=1)
+
+         skip_cases = 0
+         starting_epoch = 1
+
+         if resume_path:
+             cases_per_step = args.bs * args.gas
+             if tracker_state["last_checkpoint_step"] == train_steps:
+                 raise EricResumeError(
+                     "You provided a path to an already completed training run"
+                 )
+             skip_cases = tracker_state["last_checkpoint_step"] * cases_per_step
+             current_epoch = tracker_state["epoch"]
+             starting_epoch = current_epoch
+             if current_epoch > 1:
+                 # only skip the cases already consumed within the current epoch
+                 skip_cases = skip_cases - (current_epoch - 1) * num_train_cases
+             try:
+                 with eric_timer.section("resume", "load lr scheduler"):
+                     lr_sched.load_state_dict(torch.load(lr_sched_path))
+                 # for now we don't resume the optimizer's state as it was causing problems.
+             except Exception as e:
+                 raise EricResumeError(
+                     f"Could not load lr scheduler state from resume path: {e}"
+                 )
+
+         self.logger.info(f"Model on {next(self.model.parameters()).device}")
+
+         self.model.train()
+
+         train_loop_result: TrainResult = train_loop(
+             args=args,
+             train_dataloader=train_dataloader,
+             eric_tracker=eric_tracker,
+             train_steps=train_steps,
+             model=self.model,
+             optim=optim,
+             lr_sched=lr_sched,
+             eval_cases=num_eval_cases,
+             eval_dataloader=eval_dataloader,
+             tokenizer=self.tokenizer,
+             config=self.config,
+             skip_cases=skip_cases,
+             starting_epoch=starting_epoch,
+             eval_models=self.eval_models,
+             eric_timer=eric_timer,
+             precision_type=self.precision_type,
+         )
+
+         eric_tracker.close()
+
+         self._train_just_happened = True
+         eric_timer.report()
+         return train_loop_result
+
+     def eval(self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()) -> EvalResult:
+         eval_validator = EvalValidator(
+             model=self.model,
+             eval_path=eval_path,
+             args=args,
+             out_dir=args.out_dir,
+             logger=self.logger,
+         )
+
+         if eval_validator.eval_source == "file":
+             out_directory = create_tracker_dir(args.out_dir, "eval", args.run_name)
+             current_eval_tok_dir = self._tokenize_input_file(
+                 eval_path, out_directory, "eval", in_eval=True
+             )
+             self.logger.info(
+                 f"Eval data has been tokenized and saved to {current_eval_tok_dir}"
+             )
+         else:  # has to be a tokenized data dir
+             current_eval_tok_dir = eval_path
+
+         dataloader, tok_data_len = get_tok_data(
+             current_eval_tok_dir,
+             args.seed,
+             args.bs,
+             self._data_collator,
+             self.device,
+         )
+
+         try:
+             self.model.to(self.device)
+         except Exception as e:
+             raise EricDeviceError(f"Failed to move model to device: {e}")
+
+         self.model.eval()
+         return eval_loop(
+             self.model,
+             dataloader,
+             tok_data_len,
+             args.bs,
+             eval_models=self.eval_models,
+         )
+
+     def tok(self, path: str, out_dir: str, args: TokArgs = TokArgs()):
+         _ = TokValidator(input_data=path, out_dir=out_dir, args=args)
+
+         file_infos = resolve_input_files(path)
+
+         try:
+             os.makedirs(out_dir, exist_ok=True)
+         except Exception as e:
+             raise EricIOError(f"Failed to make directory: {out_dir}. {e}")
+         output_paths, output_data_loc = prepare_output_locations(out_dir, args.shards)
+
+         num_cases = 0
+         p_num_files = tqdm(
+             total=len(file_infos), desc="# tokenizing shard #", position=0
+         )
+
+         for file, file_type in file_infos:
+             if args.max_cases > 0 and num_cases >= args.max_cases:
+                 break
+             try:
+                 if file_type == "text":
+                     dataset = load_dataset(
+                         file_type,
+                         data_files=[file],
+                         split="train",
+                         sample_by="document",
+                     )
+                 else:
+                     dataset = load_dataset(file_type, data_files=[file], split="train")
+
+             except Exception as e:
+                 raise EricDatasetError(f"Error loading {file} of type {file_type}: {e}")
+
+             # Trim dataset if necessary
+             remaining = (
+                 args.max_cases - num_cases if args.max_cases > 0 else len(dataset)
+             )
+             trimmed_count = min(len(dataset), remaining)
+             dataset = dataset.select(range(trimmed_count))
+             procs = get_procs(args.procs)
+             # Tokenize trimmed data
+             tokenized = self._tok_function(
+                 dataset, args=args, file_type=file_type, procs=procs
+             )
+
+             save_json_tok_data(
+                 tokenized,
+                 output_data_loc,
+                 args.shards,
+                 self._format_tokenized_example,
+             )
+
+             num_cases += len(tokenized)
+
+             p_num_files.update(1)
+
+         write_details_file(out_dir, num_cases, output_paths)
+
+         for f in output_data_loc:
+             f.close()
+
+     def save(self, path: str):
+         if self.model is None:
+             raise EricNoModelError(
+                 f"No model found. Provide a model_name or PreTrainedModel when instantiating {self.__class__.__name__}."
+             )
+         try:
+             self.model.save_pretrained(path)
+         except Exception as e:
+             raise EricSaveError(f"Error saving model to {path}: {e}")
+         try:
+             self.tokenizer.save_pretrained(path)
+         except Exception as e:
+             raise EricSaveError(f"Error saving tokenizer to {path}: {e}")
+         try:
+             self.config.save_pretrained(path)
+         except Exception as e:
+             raise EricSaveError(f"Error saving config to {path}: {e}")
+
+         self._prep_model()  # Reset self.model.generation_config with the Eric Transformer values for CHAT, GEN and TT
+
+     def push(self, repo_id: str, private: bool = True):
+         if self.model is None:
+             raise EricNoModelError(
+                 f"No model found. Provide a model_name or PreTrainedModel when instantiating {self.__class__.__name__}."
+             )
+
+         api = HfApi()
+         try:
+             api.create_repo(repo_id, exist_ok=True, private=private)
+         except Exception as e:
+             self.logger.warning(f"Could not create repo: {e}")
+             return
+         try:
+             has_readme = api.file_exists(repo_id, "README.md")
+         except Exception as e:
+             self.logger.warning(f"Could not check for README.md: {e}")
+             return
+
+         if not has_readme:
+             readme_text = self._get_readme(repo_id)
+             try:
+                 self.logger.info("Pushing README...")
+
+                 api.upload_file(
+                     path_or_fileobj=readme_text.encode("utf-8"),
+                     path_in_repo="README.md",
+                     repo_id=repo_id,
+                     repo_type="model",
+                 )
+
+             except Exception as e:
+                 # Don't fail the whole push if README upload fails; just warn.
+                 self.logger.warning(f"Error pushing README: {e}")
+
+         try:
+             self.logger.info("Pushing model...")
+             self.model.push_to_hub(
+                 repo_id,
+                 private=private,
+                 commit_message="Uploaded model from Eric Transformer",
+             )
+         except Exception as e:
+             raise EricPushError(f"Error pushing model: {e}")
+         try:
+             self.logger.info("Pushing tokenizer...")
+             self.tokenizer.push_to_hub(
+                 repo_id,
+                 private=private,
+                 commit_message="Uploaded tokenizer from Eric Transformer",
+             )
+         except Exception as e:
+             raise EricPushError(f"Error pushing tokenizer: {e}")
+         try:
+             self.logger.info("Pushing config...")
+             self.config.push_to_hub(
+                 repo_id,
+                 private=private,
+                 commit_message="Uploaded config from Eric Transformer",
+             )
+         except Exception as e:
+             raise EricPushError(f"Error pushing config: {e}")
+
+     def _tokenize_input_file(
+         self, train_path: str, out_dir: str, label: str, in_eval: bool = False
+     ) -> str:
+         dir_name = f"tok_{label}_data" if in_eval else f"data/tok_{label}_data"
+         tok_dir = os.path.join(out_dir, dir_name)
+         try:
+             os.makedirs(tok_dir, exist_ok=True)
+         except Exception as e:
+             raise EricIOError(f"Error making directory ({tok_dir}): {e}")
+
+         self.tok(train_path, tok_dir)
+         self.logger.info(
+             f"{label.capitalize()} data has been tokenized and saved to {tok_dir}"
+         )
+         return tok_dir
+
+     def _get_model_ready_inference(self):
+         if self._train_just_happened:
+             self.logger.info(f"Moving model to {self.device}")
+             try:
+                 self.model.to(self.device)
+             except Exception as e:
+                 raise EricDeviceError(f"Failed to move model to device: {e}")
+
+             self._train_just_happened = False
+
+     @staticmethod
+     def tok_dir_to_dataset(tok_dir: str) -> Dataset:
+         return tok_dir_to_dataset(tok_dir)
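
For reference, the warmup_then_decay schedule assembled in train() above is a composition of three stock PyTorch schedulers. Here is a minimal, self-contained sketch of the same schedule with a stand-in one-layer model and assumed values for the peak learning rate and step count (nothing below comes from the package itself):

import math
import torch

model = torch.nn.Linear(4, 2)  # stand-in model
peak_lr = 3e-4                 # assumed value; in train() this is args.lr
train_steps = 100              # assumed; train() derives this via get_num_training_steps
num_warmup_steps = math.ceil(train_steps / 10)  # 10% warmup, as above

optim = torch.optim.AdamW(model.parameters(), lr=peak_lr)

# linear ramp from 10% of peak_lr up to peak_lr over the warmup steps
warmup = torch.optim.lr_scheduler.LinearLR(
    optim, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps
)
# cosine decay from peak_lr down to 10% of peak_lr over the remaining steps
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, T_max=train_steps - num_warmup_steps, eta_min=0.1 * peak_lr
)
# chain the two, switching from warmup to cosine at the milestone step
lr_sched = torch.optim.lr_scheduler.SequentialLR(
    optim, schedulers=[warmup, cosine], milestones=[num_warmup_steps]
)

for step in range(train_steps):
    optim.step()      # one (dummy) optimizer step
    lr_sched.step()   # advance the schedule
    if step % 25 == 0:
        print(step, lr_sched.get_last_lr())
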
erictransformer/eval_models/__init__.py
@@ -0,0 +1 @@
+ from erictransformer.eval_models.eval_model import EvalModel, TCAccuracyEvalModel
erictransformer/eval_models/eval_model.py
@@ -0,0 +1,75 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Optional
+
+ import torch
+ from transformers.modeling_outputs import ModelOutput
+
+ from erictransformer.exceptions import EricEvalModelError
+
+
+ class EvalModel(ABC):
+     def __init__(self, name: str):
+         self.name = name
+
+     @abstractmethod
+     def reset(self) -> None:
+         pass
+
+     @abstractmethod
+     def __call__(self, batch: Dict[str, torch.Tensor], outputs: ModelOutput) -> None:
+         pass
+
+     @abstractmethod
+     def result(self) -> Optional[float]:
+         pass
+
+     @abstractmethod
+     def _confirm_batch_params(
+         self, batch: Dict[str, torch.Tensor], outputs: ModelOutput
+     ) -> None:
+         pass
+
+
+ class TCAccuracyEvalModel(EvalModel):
+     def __init__(self):
+         self.correct_preds = 0
+         self.total_preds = 0
+         self.name = "accuracy"
+         super().__init__(self.name)
+
+     def reset(self) -> None:
+         self.correct_preds = 0
+         self.total_preds = 0
+
+     def __call__(self, batch: Dict[str, torch.Tensor], outputs: ModelOutput) -> None:
+         try:
+             self._confirm_batch_params(batch, outputs)
+             logits = outputs.logits
+             labels = batch["labels"]
+             preds = torch.argmax(logits, dim=-1)
+             self.correct_preds += (preds == labels).sum().item()
+             self.total_preds += labels.size(0)
+         except Exception as e:
+             raise EricEvalModelError(f"error calling {self.__class__.__name__}: {e}")
+
+     def result(self) -> Optional[float]:
+         return self.correct_preds / self.total_preds if self.total_preds > 0 else None
+
+     def _confirm_batch_params(
+         self, batch: Dict[str, torch.Tensor], outputs: ModelOutput
+     ) -> None:
+         if "labels" not in batch:
+             raise EricEvalModelError("Batch must contain a 'labels' key.")
+         if not isinstance(batch["labels"], torch.Tensor):
+             raise EricEvalModelError("batch['labels'] must be a torch.Tensor.")
+         if not hasattr(outputs, "logits"):
+             raise EricEvalModelError("ModelOutput must have a 'logits' attribute.")
+         if not isinstance(outputs.logits, torch.Tensor):
+             raise EricEvalModelError("outputs.logits must be a torch.Tensor.")
+
+         bs_logits = outputs.logits.shape[0]
+         bs_labels = batch["labels"].shape[0]
+         if bs_logits != bs_labels:
+             raise EricEvalModelError(
+                 f"Mismatch in batch size: logits has {bs_logits}, labels has {bs_labels}"
+             )
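
The eval loop drives implementations of this interface as reset, then repeated __call__, then result. A small sketch of that protocol run directly against TCAccuracyEvalModel with hand-made tensors (assumes the package is installed; SequenceClassifierOutput serves only as a convenient ModelOutput that carries logits):

import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from erictransformer.eval_models import TCAccuracyEvalModel

metric = TCAccuracyEvalModel()
metric.reset()

# one fake batch of three 2-class predictions
logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
labels = torch.tensor([0, 1, 1])
metric({"labels": labels}, SequenceClassifierOutput(logits=logits))

# argmax picks classes [0, 1, 0], so two of three match the labels
print(metric.result())  # 0.666...
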
erictransformer/exceptions/__init__.py
@@ -0,0 +1,19 @@
+ from erictransformer.exceptions.eric_exceptions import (
+     EricChatTemplateError,
+     EricDatasetError,
+     EricDeviceError,
+     EricEvalError,
+     EricEvalModelError,
+     EricInferenceError,
+     EricInputError,
+     EricIOError,
+     EricLoadModelError,
+     EricLoadPipelineError,
+     EricLoadTokenizerError,
+     EricNoModelError,
+     EricPushError,
+     EricResumeError,
+     EricSaveError,
+     EricTokenizationError,
+     EricTrainError,
+ )
erictransformer/exceptions/eric_exceptions.py
@@ -0,0 +1,74 @@
+ class EricTransformerError(Exception):
+     pass
+
+
+ class EricInputError(EricTransformerError):
+     pass
+
+
+ class EricLoadModelError(EricTransformerError):
+     pass
+
+
+ class EricLoadTokenizerError(EricTransformerError):
+     pass
+
+
+ class EricLoadPipelineError(EricTransformerError):
+     pass
+
+
+ class EricResumeError(EricTransformerError):
+     pass
+
+
+ class EricInferenceError(EricTransformerError):
+     pass
+
+
+ class EricDeviceError(EricTransformerError):
+     pass
+
+
+ class EricIOError(EricTransformerError):
+     pass
+
+
+ class EricTokenizationError(EricTransformerError):
+     pass
+
+
+ class EricDatasetError(EricTransformerError):
+     pass
+
+
+ class EricChatTemplateError(EricTransformerError):
+     pass
+
+
+ class EricEvalModelError(EricTransformerError):
+     pass
+
+
+ class EricSaveError(EricTransformerError):
+     pass
+
+
+ class EricNoModelError(EricTransformerError):
+     pass
+
+
+ class EricPushError(EricTransformerError):
+     pass
+
+
+ class EricPlotError(EricTransformerError):
+     pass
+
+
+ class EricEvalError(EricTransformerError):
+     pass
+
+
+ class EricTrainError(EricTransformerError):
+     pass
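
Because every exception above subclasses EricTransformerError, callers can trap any library failure with a single except clause. A hedged sketch (note that EricTransformerError itself is not re-exported by the exceptions __init__ shown earlier, so it is imported from the defining module; save_safely is a hypothetical helper, not part of the package):

from erictransformer.exceptions.eric_exceptions import EricTransformerError

def save_safely(et, path: str) -> bool:
    # et is assumed to be an instance of any EricTransformer subclass
    try:
        et.save(path)  # may raise EricNoModelError or EricSaveError
        return True
    except EricTransformerError as e:
        # one handler covers every Eric* exception defined above
        print(f"save failed: {e}")
        return False
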
erictransformer/loops/__init__.py
@@ -0,0 +1,2 @@
+ from erictransformer.loops.eval_loop import EvalResult, eval_loop
+ from erictransformer.loops.train_loop import TrainResult, train_loop