erictransformer-0.0.1-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (83)
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/eric_tasks/eric_text_to_text.py
@@ -0,0 +1,283 @@
+ import textwrap
+ import threading
+ from dataclasses import dataclass
+ from typing import Iterator, List, Optional, Tuple, Union
+
+ from datasets import Dataset
+ from transformers import (
+     AutoModelForSeq2SeqLM,
+     DataCollatorForSeq2Seq,
+     GenerationConfig,
+     PretrainedConfig,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+     PreTrainedTokenizerBase,
+     PreTrainedTokenizerFast,
+     TextIteratorStreamer,
+ )
+
+ from erictransformer.args import EricTrainArgs, EricEvalArgs
+ from erictransformer.eval_models import EvalModel
+ from erictransformer.exceptions import EricTokenizationError
+ from erictransformer.eric_tasks.args import (
+     TTCallArgs,
+     TTTokArgs,
+ )
+ from erictransformer.eric_tasks.misc import generate_tt_kwargs, get_pad_eos
+ from erictransformer.eric_tasks.results import TTResult
+ from erictransformer.eric_tasks.tok.tok_functions import get_max_in_len
+ from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
+ from erictransformer.loops import EvalResult
+ from erictransformer.utils import get_model_components
+ from erictransformer.validator import TTValidator
+
+
+ @dataclass(kw_only=True)
+ class TTStreamResult:
+     text: str
+
+
+ class EricTextToText(EricTransformer):
+     def __init__(
+         self,
+         model_name: Union[str, PreTrainedModel, None] = "google-t5/t5-base",
+         *,
+         trust_remote_code: bool = False,
+         tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast, None] = None,
+     ):
+         model_class = AutoModelForSeq2SeqLM
+         eric_args = EricTransformerArgs(
+             model_name=model_name,
+             model_class=model_class,
+             trust_remote_code=trust_remote_code,
+             tokenizer=tokenizer,
+         )
+         super().__init__(eric_args)
+
+         self.task_validator = TTValidator(
+             model_name=model_name,
+             trust_remote_code=trust_remote_code,
+             tokenizer=tokenizer,
+             logger=self.logger,
+         )
+
+         if self.model is not None:
+             self.pad_token_id, self.eos_token_id = get_pad_eos(
+                 self.tokenizer, self.model
+             )
+
+             self._prep_model()
+             self._data_collator = DataCollatorForSeq2Seq(
+                 self.tokenizer, model=self.model
+             )
+
+     def _get_call_thread_streamer(self, text: str, args: TTCallArgs = TTCallArgs()):
+         tokens = self.tokenizer(text, return_tensors="pt", truncation=True).to(
+             self.device
+         )
+         input_ids = tokens["input_ids"]
+         attention_mask = tokens["attention_mask"]
+         if input_ids.ndim == 1:
+             input_ids = input_ids.unsqueeze(0)
+
+         gen_streamer = TextIteratorStreamer(
+             self.tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+
+         gen_kwargs = generate_tt_kwargs(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             streamer=gen_streamer,
+             args=args,
+             eos_token_id=self.eos_token_id,
+             pad_token_id=self.pad_token_id,
+         )
+
+         # generate() runs on a background thread; the streamer yields decoded
+         # text on the calling thread.
+         gen_thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)
+
+         return gen_thread, gen_streamer
+
+     def __call__(
+         self,
+         text: str,
+         args: TTCallArgs = TTCallArgs(),
+     ) -> TTResult:
+         self._get_model_ready_inference()
+         self.task_validator.validate_call(text, args)
+         gen_thread, gen_streamer = self._get_call_thread_streamer(text, args)
+         gen_thread.start()
+         out_text = []
+         try:
+             for chunk in gen_streamer:
+                 out_text.append(chunk)
+         finally:
+             gen_thread.join()
+
+         final_text = "".join(out_text)
+         return TTResult(text=final_text)
+
+     def stream(
+         self, text: str, args: TTCallArgs = TTCallArgs()
+     ) -> Iterator[TTStreamResult]:
+         self._get_model_ready_inference()
+         self.task_validator.validate_call(text, args)
+
+         gen_thread, gen_streamer = self._get_call_thread_streamer(text, args)
+         gen_thread.start()
+         try:
+             for chunk in gen_streamer:
+                 yield TTStreamResult(text=chunk)
+         finally:
+             gen_thread.join()
+
+     def _tok_function(
+         self,
+         raw_dataset,
+         args: TTTokArgs = TTTokArgs(),
+         file_type: str = "",
+         procs: Optional[int] = None,
+     ) -> Dataset:
+         final_max_in_len = get_max_in_len(
+             args.max_in_len, self.tokenizer
+         )
+         final_max_out_len = get_max_in_len(
+             args.max_out_len, self.tokenizer
+         )
+
+         def __preprocess_function(examples):
+             try:
+                 model_inputs = self.tokenizer(
+                     examples["input"],
+                     max_length=final_max_in_len,
+                     truncation=True,
+                     padding="max_length",
+                 )
+
+                 labels = self.tokenizer(
+                     examples["target"],
+                     max_length=final_max_out_len,
+                     truncation=True,
+                     padding="max_length",
+                 )
+
+                 # Replace pad-token ids in the labels with -100 so the loss
+                 # ignores padded positions.
+                 model_inputs["labels"] = [
+                     [
+                         (tok if tok != self.tokenizer.pad_token_id else -100)
+                         for tok in seq
+                     ]
+                     for seq in labels["input_ids"]
+                 ]
+                 return model_inputs
+             except Exception as e:
+                 raise EricTokenizationError(
+                     f"Tokenization failed during preprocessing: {e}"
+                 )
+
+         try:
+             tok_dataset = raw_dataset.map(
+                 __preprocess_function,
+                 batched=True,
+                 remove_columns=["input", "target"],
+                 batch_size=args.bs,
+                 desc="Tokenizing...",
+                 num_proc=procs,
+             )
+             tok_dataset.set_format(
+                 type="torch", columns=["input_ids", "attention_mask", "labels"]
+             )
+             return tok_dataset
+         except Exception as e:
+             raise EricTokenizationError(
+                 f"Failed to apply preprocessing function over dataset: {e}"
+             )
+
+     def train(
+         self,
+         train_path: str = "",
+         args: EricTrainArgs = EricTrainArgs(),
+         eval_path: str = "",
+         resume_path: str = "",
+     ):
+         return super().train(train_path, args, eval_path, resume_path=resume_path)
+
+     def eval(self, eval_path: str = "", args=EricEvalArgs()) -> EvalResult:
+         return super().eval(eval_path=eval_path, args=args)
+
+     def tok(
+         self,
+         path: str,
+         out_dir: str,
+         args: TTTokArgs = TTTokArgs(),
+     ):
+         return super().tok(path=path, out_dir=out_dir, args=args)
+
+     def _load_model_components(
+         self,
+     ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
+         return get_model_components(
+             model_name_path=self.eric_args.model_name,
+             trust_remote_code=self.eric_args.trust_remote_code,
+             model_class=self.eric_args.model_class,
+             tokenizer_path=self.eric_args.tokenizer,
+             precision=self.precision_type,
+         )
+
+     def _format_tokenized_example(self, example: dict) -> dict:
+         return {
+             "input_ids": example["input_ids"],
+             "attention_mask": example["attention_mask"],
+             "labels": example["labels"],
+         }
+
+     def _get_default_eval_models(self) -> List[EvalModel]:
+         return []
+
+     def _prep_model(self):
+         generation_config = GenerationConfig.from_model_config(self.model.config)
+         args = TTCallArgs()
+         generation_config.num_beams = 1
+         generation_config.early_stopping = False
+         generation_config.do_sample = True
+         generation_config.min_length = args.min_len
+         generation_config.max_length = args.max_len
+         generation_config.temperature = args.temp
+         generation_config.top_p = args.top_p
+         self.model.generation_config = generation_config
+
+     def _get_readme(self, repo_id: str) -> str:
+         readme_text = textwrap.dedent(f"""\
+             ---
+             tags:
+             - erictransformer
+             - eric-text-to-text
+             ---
+             # {repo_id}
+
+             ## Installation
+
+             ```
+             pip install erictransformer
+             ```
+
+             ## Usage
+
+             ```python
+             from erictransformer import EricTextToText, TTCallArgs
+
+             eric_tt = EricTextToText(model_name="{repo_id}")
+
+             text = 'Hello world'
+
+             result = eric_tt(text)
+             print(result.text)
+
+             # Stream
+             for chunk in eric_tt.stream(text):
+                 print(chunk.text, end="")
+             ```
+
+             See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
+             """)
+
+         return readme_text
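`_get_call_thread_streamer` above relies on the standard Hugging Face pattern of running `generate` on a background thread while a `TextIteratorStreamer` yields decoded text on the calling thread. A minimal self-contained sketch of that pattern (the checkpoint and prompt are illustrative placeholders, not taken from this package):

```python
# Minimal sketch of the thread + TextIteratorStreamer pattern; the model
# name and prompt are placeholders.
import threading

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

tokens = tokenizer("translate English to German: Hello world", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs in a background thread; iterating the streamer on the
# calling thread yields text chunks as tokens are produced.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**tokens, streamer=streamer, max_new_tokens=40),
)
thread.start()
try:
    for chunk in streamer:
        print(chunk, end="")
finally:
    thread.join()
```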
erictransformer/eric_tasks/inference_engine/__init__.py
@@ -0,0 +1,3 @@
+ from erictransformer.eric_tasks.inference_engine.text_classification import (
+     tc_inference,
+ )
erictransformer/eric_tasks/inference_engine/text_classification.py
@@ -0,0 +1,28 @@
+ from typing import List, Tuple
+
+ import torch
+
+
+ @torch.inference_mode()
+ def tc_inference(
+     tokens,
+     model,
+     id2label,
+ ) -> List[List[Tuple[str, float]]]:
+     input_ids: torch.Tensor = tokens["input_ids"]
+     attention_mask: torch.Tensor = tokens["attention_mask"]
+     outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+     logits = outputs.logits
+     probs = torch.softmax(logits, dim=-1)
+
+     scores, indices = probs.sort(dim=-1, descending=True)
+
+     output: List[List[Tuple[str, float]]] = [
+         [
+             (id2label[int(i)], float(s))
+             for i, s in zip(idx_row.tolist(), sc_row.tolist())
+         ]
+         for idx_row, sc_row in zip(indices, scores)
+     ]
+
+     return output
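A hedged usage sketch of `tc_inference`; the checkpoint name below is a placeholder, and any sequence-classification model whose config carries an `id2label` mapping should work the same way:

```python
# Illustrative only: the checkpoint is an assumption, not part of this package.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

tokens = tokenizer(
    ["great movie", "terrible plot"], return_tensors="pt", padding=True
)
results = tc_inference(tokens, model, model.config.id2label)
# Each inner list holds (label, score) pairs sorted by descending score,
# e.g. results[0][0] might be ("POSITIVE", 0.99).
```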
erictransformer/eric_tasks/misc/__init__.py
@@ -0,0 +1,11 @@
+ from erictransformer.eric_tasks.misc.call_utils import (
+     format_messages,
+     generate_gen_kwargs,
+     generate_tt_kwargs,
+ )
+ from erictransformer.eric_tasks.misc.get_pad_eos import get_pad_eos
+ from erictransformer.eric_tasks.misc.rag import (
+     create_search_prompt_chat,
+     formate_rag_content,
+     formate_rag_message,
+ )
erictransformer/eric_tasks/misc/call_utils.py
@@ -0,0 +1,69 @@
+ from typing import Union
+
+ from transformers import TextIteratorStreamer
+
+ from erictransformer.exceptions import EricInferenceError
+ from erictransformer.eric_tasks.args import CHATCallArgs, GENCallArgs, TTCallArgs
+
+
+ def format_messages(text):
+     if isinstance(text, str):
+         messages = [{"role": "user", "content": text}]
+     elif isinstance(text, list) and all(isinstance(el, dict) for el in text):
+         messages = text
+     else:
+         raise EricInferenceError("Input must be a string or a list of message dicts")
+
+     return messages
+
+
+ def generate_gen_kwargs(
+     input_ids,
+     attention_mask,
+     streamer: TextIteratorStreamer,
+     args: Union[CHATCallArgs, GENCallArgs],
+     eos_token_id: int,
+     pad_token_id: int,
+ ) -> dict:
+     # Ensure the generation budget is at least min_len.
+     max_len = args.max_len
+     if args.min_len > args.max_len:
+         max_len = args.min_len
+
+     gen_kwargs = dict(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         streamer=streamer,
+         max_new_tokens=max_len,
+         temperature=args.temp,
+         top_p=args.top_p,
+         eos_token_id=eos_token_id,
+         pad_token_id=pad_token_id,
+     )
+
+     return gen_kwargs
+
+
+ def generate_tt_kwargs(
+     input_ids,
+     attention_mask,
+     streamer: TextIteratorStreamer,
+     args: Union[CHATCallArgs, TTCallArgs],
+     eos_token_id: int,
+     pad_token_id: int,
+ ) -> dict:
+     # Ensure the generation budget is at least min_len.
+     max_len = args.max_len
+     if args.min_len > args.max_len:
+         max_len = args.min_len
+
+     gen_kwargs = dict(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         streamer=streamer,
+         max_new_tokens=max_len,
+         temperature=args.temp,
+         top_p=args.top_p,
+         eos_token_id=eos_token_id,
+         pad_token_id=pad_token_id,
+     )
+
+     return gen_kwargs
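To make `format_messages` concrete, a short sketch of the two accepted input shapes:

```python
from erictransformer.eric_tasks.misc import format_messages

# A bare string is wrapped into a single-message chat list.
format_messages("Hello")
# -> [{"role": "user", "content": "Hello"}]

# A list of message dicts passes through unchanged.
format_messages([{"role": "user", "content": "Hello"}])
# -> [{"role": "user", "content": "Hello"}]

# Anything else (e.g. an int) raises EricInferenceError.
```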
erictransformer/eric_tasks/misc/get_pad_eos.py
@@ -0,0 +1,24 @@
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+ from erictransformer.exceptions import EricInferenceError
+
+
+ def get_pad_eos(tokenizer: PreTrainedTokenizer, model: PreTrainedModel):
+     # Fall back to the eos token for padding when no pad token is defined.
+     if tokenizer.pad_token_id is not None:
+         pad_id = tokenizer.pad_token_id
+     elif tokenizer.eos_token_id is not None:
+         pad_id = tokenizer.eos_token_id
+     else:
+         raise EricInferenceError(
+             "Tokenizer doesn't define a pad_token_id or an eos_token_id"
+         )
+
+     if model.config.eos_token_id is not None:
+         eos_id = model.config.eos_token_id
+     elif tokenizer.eos_token_id is not None:
+         eos_id = tokenizer.eos_token_id
+     else:
+         raise EricInferenceError(
+             "The model and the tokenizer don't define an eos_token_id"
+         )
+     return pad_id, eos_id
erictransformer/eric_tasks/misc/rag.py
@@ -0,0 +1,17 @@
+ from typing import List, Union
+
+ from ericsearch import RankerResult
+
+
+ def formate_rag_content(text: str, data_result: RankerResult) -> str:
+     return f"Based on the search query: '{text}'. The following data may be relevant: ' {data_result.text} '"
+
+
+ def formate_rag_message(rag_content: str) -> dict:
+     return {"role": "user", "content": rag_content}
+
+
+ def create_search_prompt_chat(text: Union[str, List[dict]]) -> str:
+     # For chat input, use the content of the most recent message as the query.
+     if isinstance(text, str):
+         search_query = text
+     else:
+         search_query = text[-1]["content"]
+     return search_query
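A sketch of how these three helpers compose into a retrieval-augmented chat message. `RankerResult` comes from `ericsearch`; here it is stubbed with a stand-in dataclass so the example is self-contained:

```python
from dataclasses import dataclass

from erictransformer.eric_tasks.misc import (
    create_search_prompt_chat,
    formate_rag_content,
    formate_rag_message,
)


@dataclass
class StubRankerResult:  # stand-in for ericsearch.RankerResult
    text: str


# The query is the content of the most recent chat message.
query = create_search_prompt_chat([{"role": "user", "content": "What is RAG?"}])
content = formate_rag_content(query, StubRankerResult(text="retrieved passage"))
message = formate_rag_message(content)
# message == {"role": "user", "content": "Based on the search query: ..."}
```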
erictransformer/eric_tasks/results/__init__.py
@@ -0,0 +1,6 @@
+ from erictransformer.eric_tasks.results.call_results import (
+     CHATResult,
+     GENResult,
+     TCResult,
+     TTResult,
+ )
erictransformer/eric_tasks/results/call_results.py
@@ -0,0 +1,30 @@
+ from dataclasses import dataclass
+ from typing import List
+
+ from erictransformer.args import CallResult
+
+
+ @dataclass(kw_only=True)
+ class GENResult(CallResult):
+     text: str
+
+
+ @dataclass(kw_only=True)
+ class CHATResult(CallResult):
+     text: str
+
+
+ @dataclass(kw_only=True)
+ class StreamCHATResult(CallResult):
+     text: str
+     mode: str
+
+
+ @dataclass(kw_only=True)
+ class TCResult(CallResult):
+     labels: List[str]
+     scores: List[float]
+
+
+ @dataclass(kw_only=True)
+ class TTResult(CallResult):
+     text: str
erictransformer/eric_tasks/tok/__init__.py
File without changes
erictransformer/eric_tasks/tok/tok_functions.py
@@ -0,0 +1,118 @@
+ from typing import Union
+
+ from datasets import Dataset
+ from transformers import PreTrainedTokenizer
+
+ from erictransformer.exceptions import EricTokenizationError
+
+
+ def get_max_in_len(
+     max_len: Union[int, None], tokenizer: PreTrainedTokenizer
+ ) -> int:
+     if max_len is not None and max_len >= 1:
+         return max_len
+     # Some tokenizers report a huge sentinel model_max_length; fall back
+     # to 512 in that case.
+     if tokenizer.model_max_length > 10_000_000:
+         return 512
+     return tokenizer.model_max_length
+
+
+ def tokenize_gen(
+     tokenizer: PreTrainedTokenizer,
+     dataset: Dataset,
+     max_len: int,
+     bs: int,
+     procs: int = 1,
+ ) -> Dataset:
+     def tokenize_fn(batch):
+         try:
+             tokens = tokenizer(
+                 batch["text"],
+                 padding="max_length",
+                 truncation=True,
+                 max_length=max_len,
+             )
+             # Mask pad positions with -100 so the loss ignores them.
+             labels = []
+             for input_ids in tokens["input_ids"]:
+                 masked = [
+                     token if token != tokenizer.pad_token_id else -100
+                     for token in input_ids
+                 ]
+                 labels.append(masked)
+             tokens["labels"] = labels
+             return tokens
+         except Exception as e:
+             raise EricTokenizationError(
+                 f"Tokenization failed during batch processing. Error: {e}"
+             )
+
+     try:
+         tokenized = dataset.map(
+             tokenize_fn,
+             batched=True,
+             remove_columns=["text"],
+             batch_size=bs,
+             desc="Tokenizing...",
+             num_proc=procs,
+         )
+         cols = ["input_ids", "attention_mask", "labels"]
+         tokenized.set_format("torch", columns=cols)
+         return tokenized
+     except Exception as e:
+         raise EricTokenizationError(
+             f"Failed during dataset mapping or formatting. Error: {e}"
+         )
+
+
+ def tokenize_chat_template(
+     tokenizer: PreTrainedTokenizer,
+     dataset: Dataset,
+     max_len: int,
+     bs: int,
+     procs: int = 1,
+ ) -> Dataset:
+     def tokenize_fn(example):
+         try:
+             inputs = example["messages"]
+             tokens = tokenizer.apply_chat_template(
+                 inputs,
+                 tokenize=True,
+                 add_generation_prompt=False,
+                 max_length=max_len,
+                 padding="max_length",
+                 return_dict=True,
+                 truncation=True,
+             )
+
+             # Mask pad positions with -100 so the loss ignores them.
+             labels = []
+             for input_ids in tokens["input_ids"]:
+                 masked = [
+                     token if token != tokenizer.pad_token_id else -100
+                     for token in input_ids
+                 ]
+                 labels.append(masked)
+             tokens["labels"] = labels
+
+             return tokens
+         except Exception as e:
+             raise EricTokenizationError(
+                 f"Tokenization failed during chat template application. Error: {e}"
+             )
+
+     try:
+         tokenized = dataset.map(
+             tokenize_fn,
+             batched=True,
+             remove_columns=["messages"],
+             batch_size=bs,
+             desc="Tokenizing...",
+             num_proc=procs,
+         )
+         tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
+         return tokenized
+     except Exception as e:
+         raise EricTokenizationError(
+             f"Failed during dataset mapping or formatting. Error: {e}"
+         )
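The `-100` used in both label-building loops above is PyTorch's default `ignore_index` for cross-entropy, so padded positions contribute nothing to the training loss. A toy illustration with made-up token ids:

```python
import torch

pad_token_id = 0
input_ids = [101, 7592, 102, 0, 0]  # two trailing pad tokens

labels = [tok if tok != pad_token_id else -100 for tok in input_ids]
assert labels == [101, 7592, 102, -100, -100]

# CrossEntropyLoss skips targets equal to ignore_index (-100 by default),
# so only the three real tokens would be scored.
loss_fn = torch.nn.CrossEntropyLoss()
```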
erictransformer/eric_tracker/__init__.py
@@ -0,0 +1 @@
+ from erictransformer.eric_tracker.eric_tracker import EricTracker