erictransformer-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/loops/eval_loop.py
@@ -0,0 +1,111 @@
+ import math
+ import sys
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ import torch
+ from torch.amp import autocast
+ from torch.nn import Module
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+
+ from erictransformer.eval_models import EvalModel
+ from erictransformer.exceptions import EricEvalError
+ from erictransformer.utils import get_precision
+
+
+ @dataclass(kw_only=True)
+ class EvalResult:
+     loss: float
+     metrics: Optional[dict] = None
+
+
+ def eval_loop(
+     model: Module,
+     dataloader: DataLoader,
+     eval_tok_data_cases: int,
+     eval_bs: int,
+     eval_models: List[EvalModel],
+     precision: str = "auto",
+ ) -> EvalResult:
+     total_loss = 0.0
+     total_examples = 0
+     pbar = tqdm(
+         total=math.ceil(eval_tok_data_cases / eval_bs),
+         desc="Model Evaluating",
+         position=1,
+         leave=False,
+         disable=not sys.stdout.isatty(),
+     )
+
+     for eval_model in eval_models:
+         eval_model.reset()
+
+     model.eval()
+
+     device = next(model.parameters()).device
+     precision_type = get_precision(device=device)
+
+     with torch.no_grad():
+         for batch in dataloader:
+             try:
+                 try:
+                     input_ids = batch["input_ids"].to(device, non_blocking=True)
+                     attention_mask = batch["attention_mask"].to(
+                         device, non_blocking=True
+                     )
+                     labels = batch["labels"].to(device, non_blocking=True)
+                 except Exception as e:
+                     raise EricEvalError(f"Failed to extract inputs from batch: {e}")
+
+                 if device.type == "cuda":
+                     dtype = (
+                         torch.bfloat16 if precision_type == "bf16" else
+                         torch.float16 if precision_type == "fp16" else
+                         torch.float32
+                     )
+                     with autocast(device_type="cuda", enabled=(dtype != torch.float32), dtype=dtype):
+                         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+                 else:
+                     outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+                 bs = batch["input_ids"].size(0)
+                 result = outputs.loss.item() * bs
+             except Exception as e:
+                 raise EricEvalError(f"Error processing batch: {e}")
+
+             if not math.isnan(result):
+                 total_loss += result
+                 total_examples += bs
+             if eval_models:
+                 batch["input_ids"] = batch["input_ids"].to(device)
+                 batch["attention_mask"] = batch["attention_mask"].to(device)
+                 batch["labels"] = batch["labels"].to(device)
+
+                 for eval_model in eval_models:
+                     eval_model(batch, outputs)
+
+             pbar.update(1)
+
+     pbar.close()
+
+     custom_metrics = {}
+     for eval_model in eval_models:
+         name = eval_model.name
+         if name in custom_metrics:
+             for i in range(0, 1000):
+                 new_name = f"{eval_model.name}-{i}"
+                 if new_name not in custom_metrics:
+                     name = new_name
+                     break
+             else:
+                 raise EricEvalError(
+                     f"Couldn't find unique name for eval model named {eval_model.name}"
+                 )
+
+         custom_metrics[name] = eval_model.result()
+         eval_model.reset()
+
+     average_loss = total_loss / total_examples if total_examples else 0.0
+
+     return EvalResult(loss=average_loss, metrics=custom_metrics)
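Usage note (not part of the diff): a minimal sketch of how the eval_loop above can be driven. It assumes transformers is installed and the tiny Hub checkpoint "prajjwal1/bert-tiny" is reachable; only eval_loop's signature comes from the package, the model and data below are illustrative.

    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    from erictransformer.loops.eval_loop import eval_loop

    # Tiny illustrative classifier; any model that returns a loss for labeled batches works.
    tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)

    texts = ["a clearly positive example", "a clearly negative example"]
    encoded = tokenizer(texts, padding=True, return_tensors="pt")
    batch = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": torch.tensor([1, 0]),
    }

    # eval_loop expects batches that are dicts with input_ids / attention_mask / labels;
    # batch_size=None makes the DataLoader yield the pre-built batch unchanged.
    dataloader = DataLoader([batch], batch_size=None)

    result = eval_loop(
        model=model,
        dataloader=dataloader,
        eval_tok_data_cases=len(texts),
        eval_bs=len(texts),
        eval_models=[],  # no custom EvalModel metrics in this sketch
    )
    print(result.loss, result.metrics)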
erictransformer/loops/train_loop.py
@@ -0,0 +1,310 @@
+ import json
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ import torch
+ from torch.nn import Module
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+ from transformers import PretrainedConfig, PreTrainedTokenizerBase
+
+ from erictransformer.args.eric_args import EricTrainArgs
+ from erictransformer.eval_models import EvalModel
+ from erictransformer.exceptions import EricTrainError
+ from erictransformer.eric_tracker.eric_tracker import EricTracker
+ from erictransformer.loops.eval_loop import eval_loop
+ from erictransformer.utils import EricTimer
+ from erictransformer.utils.test import DebugHook
+
+
+ @dataclass
+ class TrainResult:
+     final_train_loss: float
+     final_eval_loss: Optional[float] = None
+     best_eval_loss: Optional[float] = None
+
+
+ def train_loop(
+     args: EricTrainArgs,
+     train_dataloader: DataLoader,
+     eric_tracker: EricTracker,
+     train_steps: int,
+     model: Module,
+     optim: Optimizer,
+     lr_sched: LRScheduler,
+     eval_cases: int,
+     eval_dataloader: DataLoader,
+     tokenizer: PreTrainedTokenizerBase,
+     config: PretrainedConfig,
+     skip_cases: int,
+     starting_epoch: int,
+     eval_models: List[EvalModel],
+     eric_timer: EricTimer,
+     precision_type: str,
+ ) -> TrainResult:
+     best_tokenizer_config_saved = False
+     checkpoint_tokenizer_config_saved = False
+     first_epoch_train_iter = iter(train_dataloader)
+
+     if skip_cases:
+         try:
+             skip_batches = max(1, int(skip_cases / args.bs))
+             for _ in tqdm(range(skip_batches)):
+                 next(first_epoch_train_iter, None)
+         except Exception as e:
+             raise EricTrainError(f"Failed to skip initial training cases: {e}")
+
+     device = next(model.parameters()).device
+
+     gas = max(1, args.gas)
+
+     start_step = eric_tracker.state.current_step
+
+     for epoch in range(starting_epoch - 1, int(args.epochs)):
+         epoch_iter = (
+             first_epoch_train_iter
+             if epoch == starting_epoch - 1
+             else iter(train_dataloader)
+         )
+
+         batch_idx = 0
+         while True:
+             if eric_tracker.state.current_step > train_steps:
+                 break
+
+             with eric_timer.section("training_core", "iter"):
+                 try:
+                     batch = next(epoch_iter)
+                 except StopIteration:
+                     break
+
+             try:
+                 with eric_timer.section("training_core", "to_device"):
+                     input_ids = batch["input_ids"].to(device, non_blocking=True)
+                     attention_mask = batch["attention_mask"].to(
+                         device, non_blocking=True
+                     )
+                     labels = batch["labels"].to(device, non_blocking=True)
+             except Exception as e:
+                 raise EricTrainError(f"Failed to extract inputs from batch: {e}")
+
+             try:
+                 with eric_timer.section("training_core", "forward"):
+                     if device.type == "cuda":
+                         with torch.autocast(
+                             device_type="cuda",
+                             enabled=precision_type != "fp32",
+                             dtype=(
+                                 torch.float16
+                                 if precision_type == "fp16"
+                                 else torch.bfloat16
+                             ),
+                         ):
+                             outputs = model(
+                                 input_ids=input_ids,
+                                 attention_mask=attention_mask,
+                                 labels=labels,
+                             )
+                     else:
+                         outputs = model(
+                             input_ids=input_ids,
+                             attention_mask=attention_mask,
+                             labels=labels,
+                         )
+
+                 with eric_timer.section("training_core", "get loss"):
+                     raw_loss = outputs.loss.detach()
+                     loss = outputs.loss / gas
+
+             except Exception as e:
+                 raise EricTrainError(f"Model forward pass failed: {e}")
+
+             try:
+                 with eric_timer.section("training_core", "backwards"):
+                     loss.backward()
+             except Exception as e:
+                 raise EricTrainError(f"Backward pass failed: {e}")
+
+             try:
+                 do_step = ((batch_idx + 1) % gas == 0) or (
+                     (batch_idx + 1) == len(train_dataloader)
+                 )
+
+                 if do_step:
+                     if precision_type == "fp16":
+                         # for now we don't clip gradients when using bf16 or fp32 since this
+                         # operation takes a noticeable amount of time and they're stable enough.
+                         with eric_timer.section("training_core", "clip gradients"):
+                             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+                     with eric_timer.section("training_core", "optim.step()"):
+                         optim.step()
+                     with eric_timer.section("training_core", "lr_sched.step()"):
+                         lr_sched.step()
+                     with eric_timer.section("training_core", "optim.zero_grad()"):
+                         optim.zero_grad()
+                     eric_tracker.optim_step()
+
+             except Exception as e:
+                 raise EricTrainError(f"optim or LR scheduler step failed: {e}")
+
+             try:
+                 do_eval = bool(eric_tracker.time_to_eval() and eval_cases)
+
+                 if do_eval:
+                     with eric_timer.section("training_extra", "eval loop"):
+                         model.eval()
+                         eval_result = eval_loop(
+                             model,
+                             eval_dataloader,
+                             eval_cases,
+                             args.eval_bs if args.eval_bs else args.bs,
+                             eval_models=eval_models,
+                         )
+                         model.train()
+
+                     global_eval_loss = float(eval_result.loss)
+
+                     is_best_model = eric_tracker.set_eval_loss(global_eval_loss)
+                     eric_tracker.set_metrics(eval_result.metrics)
+
+                     if args.save_best and is_best_model:
+                         with eric_timer.section("training_extra", "save best model"):
+                             try:
+                                 path = eric_tracker.tracker_paths.best_model_path
+                                 model.save_pretrained(path)
+                             except Exception as e:
+                                 path = eric_tracker.tracker_paths.best_model_path
+                                 raise EricTrainError(
+                                     f"Failed to save best model to: {path} | {e}"
+                                 )
+
+                             try:
+                                 state_path = (
+                                     eric_tracker.tracker_paths.best_model_state_path
+                                 )
+                                 with open(state_path, "w") as f:
+                                     json.dump(
+                                         eric_tracker.state.to_dict(), f, indent=2
+                                     )
+                             except Exception as e:
+                                 state_path = (
+                                     eric_tracker.tracker_paths.best_model_state_path
+                                 )
+                                 raise EricTrainError(
+                                     f"Failed to save best model state to: {state_path} | {e}"
+                                 )
+
+                             if not best_tokenizer_config_saved:
+                                 try:
+                                     tokenizer.save_pretrained(
+                                         eric_tracker.tracker_paths.best_model_path
+                                     )
+                                 except Exception as e:
+                                     path = eric_tracker.tracker_paths.best_model_path
+                                     raise EricTrainError(
+                                         f"Failed to save tokenizer to: {path} | {e}"
+                                     )
+                                 try:
+                                     config.save_pretrained(
+                                         eric_tracker.tracker_paths.best_model_path
+                                     )
+                                 except Exception as e:
+                                     path = eric_tracker.tracker_paths.best_model_path
+                                     raise EricTrainError(
+                                         f"Failed to save config to: {path} | {e}"
+                                     )
+                                 best_tokenizer_config_saved = True
+             except Exception as e:
+                 raise EricTrainError(f"Evaluation loop failed: {e}")
+
+             try:
+                 do_ckpt = eric_tracker.time_to_checkpoint()
+
+                 if do_ckpt:
+                     with eric_timer.section("training_extra", "checkpoint"):
+                         checkpoint_path = eric_tracker.tracker_paths.checkpoint_path
+                         try:
+                             model.save_pretrained(checkpoint_path)
+                         except Exception as e:
+                             raise EricTrainError(
+                                 f"Failed to save checkpoint model to: {checkpoint_path} | {e}"
+                             )
+
+                         if not checkpoint_tokenizer_config_saved:
+                             try:
+                                 tokenizer.save_pretrained(checkpoint_path)
+                             except Exception as e:
+                                 raise EricTrainError(
+                                     f"Failed to save tokenizer to: {checkpoint_path} | {e}"
+                                 )
+                             try:
+                                 config.save_pretrained(checkpoint_path)
+                             except Exception as e:
+                                 raise EricTrainError(
+                                     f"Failed to save config to: {checkpoint_path} | {e}"
+                                 )
+                             checkpoint_tokenizer_config_saved = True
+
+                         try:
+                             # Reloading the optimizer was causing problems. So for now we restart it.
+                             # torch.save(
+                             #     optim.state_dict(),
+                             #     eric_tracker.tracker_paths.optim_path,
+                             # )
+
+                             torch.save(
+                                 lr_sched.state_dict(),
+                                 eric_tracker.tracker_paths.lr_sched_path,
+                             )
+
+                         except Exception as e:
+                             raise EricTrainError(
+                                 f"Failed to save lr scheduler state to: "
+                                 f"{eric_tracker.tracker_paths.lr_sched_path} | {e}"
+                             )
+             except Exception as e:
+                 raise EricTrainError(str(e))
+
+             try:
+                 with eric_timer.section("training_extra", "logging"):
+                     if (
+                         eric_tracker.time_to_log()
+                         or eric_tracker.time_to_eval()
+                         or eric_tracker.time_to_checkpoint()
+                     ):
+                         eric_timer.report()
+
+                     num_tokens = int(attention_mask.sum().item())
+                     eric_tracker.step(
+                         raw_loss, lr_sched.get_last_lr()[0], num_tokens=num_tokens
+                     )
+             except Exception as e:
+                 raise EricTrainError(f"Tracker step failed: {e}")
+
+             debug_hook_post_checkpoint()
+
+             batch_idx += 1
+
+         eric_tracker.mark_epoch()
+
+         if eric_tracker.state.current_step >= train_steps:
+             break
+
+     debug_hook_steps({
+         "start_step": start_step,
+         "total_steps": eric_tracker.state.current_step,
+     })
+
+     return TrainResult(
+         final_train_loss=eric_tracker.state.train_loss,
+         final_eval_loss=eric_tracker.state.eval_loss if eval_cases else None,
+         best_eval_loss=eric_tracker.state.best_eval_loss if eval_cases else None,
+     )
+
+
+ debug_hook_post_checkpoint = DebugHook()
+ debug_hook_steps = DebugHook()
erictransformer/utils/__init__.py
@@ -0,0 +1,21 @@
+ from erictransformer.utils.init import get_model_components, et_retrieve_tokenizer
+ from erictransformer.utils.init.get_device import et_get_device
+ from erictransformer.utils.init.get_logger import et_get_logger
+ from erictransformer.utils.timer import EricTimer
+ from erictransformer.utils.tok_data import (
+     get_procs,
+     prepare_output_locations,
+     resolve_input_files,
+     tok_dir_to_dataset,
+     write_details_file,
+ )
+ from erictransformer.utils.tok_data.save_tok_data import save_json_tok_data
+ from erictransformer.utils.train import (
+     create_tracker_dir,
+     get_num_training_steps,
+     get_optim,
+     get_precision,
+     get_tok_data,
+     make_dir,
+     resume_training,
+ )
erictransformer/utils/init/__init__.py
@@ -0,0 +1,5 @@
+ from erictransformer.utils.init.get_components import (
+     get_model_components,
+     get_model_components_tc,
+     et_retrieve_tokenizer,
+ )
erictransformer/utils/init/get_components.py
@@ -0,0 +1,204 @@
+ from typing import List, Optional, Union
+
+ import torch
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+     PreTrainedModel,
+     PreTrainedTokenizerBase,
+ )
+
+ from erictransformer.exceptions import EricLoadModelError, EricLoadTokenizerError
+
+
+ def _get_torch_dtype(torch_dtype: str) -> Union[torch.dtype, str]:
+     if torch_dtype == "auto":
+         # "auto" is forwarded unchanged for from_pretrained to resolve
+         return "auto"
+     elif torch_dtype == "fp32":
+         return torch.float32
+     elif torch_dtype == "fp16":
+         return torch.float16
+     elif torch_dtype == "bf16":
+         return torch.bfloat16
+     else:
+         raise ValueError(
+             f"Invalid torch_dtype {torch_dtype}. Provide one of auto, fp32, fp16, or bf16."
+         )
+
+
+ def get_model(
+     model_name_path: Union[str, PreTrainedModel],
+     model_class: AutoModel,
+     trust_remote_code: bool,
+     precision,
+ ) -> PreTrainedModel:
+     try:
+         if isinstance(model_name_path, PreTrainedModel):
+             return model_name_path
+
+         loaded_config = AutoConfig.from_pretrained(
+             model_name_path, trust_remote_code=trust_remote_code
+         )
+         torch_dtype = _get_torch_dtype(precision)
+
+         model = model_class.from_pretrained(
+             model_name_path,
+             config=loaded_config,
+             trust_remote_code=trust_remote_code,
+             dtype=torch_dtype,
+         )
+
+         return model
+
+     except Exception as e:
+         raise EricLoadModelError(f"Failed to load model from '{model_name_path}': {e}")
+
+
+ def et_retrieve_tokenizer(tokenizer_path: str, trust_remote_code: bool):
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(
+             tokenizer_path,
+             trust_remote_code=trust_remote_code,
+         )
+
+         return tokenizer
+     except Exception as e:
+         raise EricLoadTokenizerError(
+             f"Failed to load tokenizer from '{tokenizer_path}': {e}"
+         )
+
+
+ def get_tokenizer(
+     model_name_path: str,
+     tokenizer_path: Union[str, PreTrainedTokenizerBase],
+     trust_remote_code: bool,
+ ) -> PreTrainedTokenizerBase:
+     if tokenizer_path is None:
+         # by default we use the same value provided to model_name_path
+         tokenizer = et_retrieve_tokenizer(
+             tokenizer_path=model_name_path,
+             trust_remote_code=trust_remote_code,
+         )
+
+     elif isinstance(tokenizer_path, str):
+         tokenizer = et_retrieve_tokenizer(
+             tokenizer_path=tokenizer_path,
+             trust_remote_code=trust_remote_code,
+         )
+
+     elif isinstance(tokenizer_path, PreTrainedTokenizerBase):
+         # tokenizer provided directly
+         tokenizer = tokenizer_path
+         # maybe remove the following code: an advanced user providing their own
+         # tokenizer might want full control
+         if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+             tokenizer.pad_token = tokenizer.eos_token
+     else:
+         raise ValueError(
+             "Invalid tokenizer_path. Provide None, a string, or a PreTrainedTokenizerBase."
+         )
+
+     if tokenizer.pad_token is None:
+         if tokenizer.eos_token is not None:
+             tokenizer.pad_token = tokenizer.eos_token
+         else:
+             tokenizer.add_special_tokens({"pad_token": "<pad>"})
+
+     return tokenizer
+
+
+ def get_model_components(
+     model_name_path: Union[str, PreTrainedModel, None],
+     trust_remote_code: bool,
+     model_class: AutoModel,
+     tokenizer_path: Union[str, PreTrainedTokenizerBase],
+     precision: str,
+ ):
+     tokenizer = get_tokenizer(
+         model_name_path=model_name_path,
+         tokenizer_path=tokenizer_path,
+         trust_remote_code=trust_remote_code,
+     )
+
+     if model_name_path is not None:
+         model = get_model(
+             model_name_path=model_name_path,
+             model_class=model_class,
+             trust_remote_code=trust_remote_code,
+             precision=precision,
+         )
+         config = model.config
+     else:
+         model = None
+         config = None
+
+     return config, tokenizer, model
+
+
+ def get_model_components_tc(
+     model_name_path: Union[str, PreTrainedModel, None],
+     trust_remote_code: bool,
+     model_class: AutoModel,
+     tokenizer_path: Union[str, PreTrainedTokenizerBase],
+     precision: str,
+     labels: Optional[List[str]] = None,
+ ):
+     config = AutoConfig.from_pretrained(
+         model_name_path, trust_remote_code=trust_remote_code
+     )
+     # set to True if we should reset the labels to the size of num_labels
+     reset_labels = False
+
+     config_id2label = getattr(config, "id2label", None)
+     config_label2id = getattr(config, "label2id", None)
+
+     if not config_id2label or not config_label2id:
+         reset_labels = True
+     elif labels is not None:
+         reset_labels = True
+     else:
+         id2label_length = len(config_id2label)
+         label2id_length = len(config_label2id)
+         if id2label_length != label2id_length:
+             reset_labels = True
+
+     if reset_labels:
+         if labels is None:
+             labels = ["LABEL_0", "LABEL_1"]
+
+         if labels:
+             config.id2label = {i: labels[i] for i in range(len(labels))}
+         else:
+             num_labels = 2
+             config.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
+
+         config.label2id = {v: k for k, v in config.id2label.items()}
+
+         config.num_labels = len(config.id2label)
+
+     torch_dtype = _get_torch_dtype(precision)
+
+     model = model_class.from_pretrained(
+         model_name_path,
+         config=config,
+         dtype=torch_dtype,
+         trust_remote_code=trust_remote_code,
+         ignore_mismatched_sizes=True,
+     )
+
+     tokenizer = get_tokenizer(
+         model_name_path=model_name_path,
+         tokenizer_path=tokenizer_path,
+         trust_remote_code=trust_remote_code,
+     )
+
+     if tokenizer.pad_token is None:
+         if tokenizer.eos_token is not None:
+             tokenizer.pad_token = tokenizer.eos_token
+         else:
+             tokenizer.add_special_tokens({"pad_token": "<pad>"})
+
+     return config, tokenizer, model
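Usage note (not part of the diff): a hedged sketch of loading a config/tokenizer/model triple through get_model_components above. The signature comes from the package; the tiny Hub checkpoint "sshleifer/tiny-gpt2" and the argument values are illustrative only.

    from transformers import AutoModelForCausalLM

    from erictransformer.utils.init.get_components import get_model_components

    config, tokenizer, model = get_model_components(
        model_name_path="sshleifer/tiny-gpt2",
        trust_remote_code=False,
        model_class=AutoModelForCausalLM,
        tokenizer_path=None,  # falls back to model_name_path
        precision="fp32",     # mapped to torch.float32 by _get_torch_dtype
    )
    print(type(model).__name__, tokenizer.pad_token)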
erictransformer/utils/init/get_device.py
@@ -0,0 +1,22 @@
+ import torch
+ from torch import device
+
+ from erictransformer.exceptions import EricDeviceError
+
+
+ def et_get_device() -> device:
+     try:
+         d = None
+         if torch.backends.mps.is_available():
+             if torch.backends.mps.is_built():
+                 d = torch.device("mps")
+
+         if torch.cuda.is_available():
+             d = torch.device("cuda:0")
+
+         if not d:
+             d = torch.device("cpu")
+
+         return d
+     except Exception as e:
+         raise EricDeviceError(f"Device selection failed: {e}")
erictransformer/utils/init/get_logger.py
@@ -0,0 +1,15 @@
+ import logging
+ from logging import Logger
+
+
+ def et_get_logger() -> Logger:
+     logger = logging.getLogger("erictransformer")
+     handler = logging.StreamHandler()
+     handler.addFilter(logging.Filter("erictransformer"))
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+         handlers=[handler],
+     )
+     return logger
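Usage note (not part of the diff): the two helpers above compose naturally; a minimal sketch, assuming only that the package is importable.

    from erictransformer.utils.init.get_device import et_get_device
    from erictransformer.utils.init.get_logger import et_get_logger

    device = et_get_device()  # CUDA if available, otherwise MPS, otherwise CPU
    logger = et_get_logger()
    logger.info(f"Selected device: {device}")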