erictransformer 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,243 @@
1
+ import textwrap
2
+ import threading
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from datasets import Dataset
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ GenerationConfig,
10
+ PretrainedConfig,
11
+ PreTrainedModel,
12
+ PreTrainedTokenizerBase,
13
+ TextIteratorStreamer,
14
+ default_data_collator,
15
+ )
16
+
17
+ from erictransformer.args import EricTrainArgs, EricEvalArgs
18
+ from erictransformer.eval_models import EvalModel
19
+ from erictransformer.exceptions import EricInferenceError
20
+ from erictransformer.eric_tasks.args import (
21
+ GENCallArgs,
22
+ GENTokArgs,
23
+ )
24
+ from erictransformer.eric_tasks.misc import generate_gen_kwargs, get_pad_eos
25
+ from erictransformer.eric_tasks.results import GENResult
26
+ from erictransformer.eric_tasks.tok.tok_functions import (
27
+ get_max_in_len,
28
+ tokenize_gen,
29
+ )
30
+ from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
31
+ from erictransformer.loops import EvalResult
32
+ from erictransformer.utils import get_model_components
33
+ from erictransformer.validator import GENValidator
34
+
35
+
36
class EricGeneration(EricTransformer):
    """Text-generation task built on a causal language model.

    Wraps ``AutoModelForCausalLM`` behind the :class:`EricTransformer`
    interface. Calling the instance runs ``model.generate`` on a background
    thread and streams decoded text through a ``TextIteratorStreamer``;
    ``train``/``eval``/``tok`` delegate to the base class with
    generation-specific tokenization.
    """

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None] = "cerebras/Cerebras-GPT-111M",
        *,
        trust_remote_code: bool = False,
        tokenizer: Union[str, PreTrainedTokenizerBase] = None,
    ):
        """Load (or wrap) a causal LM and prepare it for generation.

        Args:
            model_name: Hub id/path, an already-loaded model, or ``None``
                to defer loading.
            trust_remote_code: Forwarded to the ``transformers`` loaders.
            tokenizer: Hub id/path or tokenizer instance; ``None`` means the
                model's own tokenizer is resolved by the component loader.
        """
        model_class = AutoModelForCausalLM

        eric_args = EricTransformerArgs(
            model_name=model_name,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
        )

        super().__init__(eric_args)
        self.task_validator = GENValidator(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=self.logger,
        )

        self._data_collator = default_data_collator

        if self.model is not None:
            self.pad_token_id, self.eos_token_id = get_pad_eos(
                self.tokenizer, self.model
            )
            self._prep_model()

    def _get_call_thread_streamer(self, text: str, args: GENCallArgs = GENCallArgs()):
        """Build the generation worker thread and its text streamer.

        Returns:
            A ``(thread, streamer)`` pair. The caller starts the thread and
            iterates the streamer to consume generated text incrementally.
        """
        input_ids = self.tokenizer.encode(text, return_tensors="pt")
        # Some tokenizers return a 1-D tensor; generate() expects a batch dim.
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        input_ids = input_ids.to(self.model.device)

        # No padding is applied to a single prompt, so the mask is all ones.
        attention_mask = torch.ones_like(
            input_ids, dtype=torch.long, device=self.model.device
        )

        gen_streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=False
        )

        gen_kwargs = generate_gen_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            streamer=gen_streamer,
            args=args,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        gen_thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)
        return gen_thread, gen_streamer

    def __call__(
        self, text: str, args: GENCallArgs = GENCallArgs()
    ) -> GENResult:
        """Generate a continuation of *text*.

        Args:
            text: The prompt to continue.
            args: Per-call generation settings.

        Returns:
            A :class:`GENResult` holding the fully assembled generated text.
        """
        self._get_model_ready_inference()
        gen_thread, gen_streamer = self._get_call_thread_streamer(text, args)
        gen_thread.start()
        out_text = []
        try:
            for stream_result in gen_streamer:
                if stream_result:
                    out_text.append(stream_result)
        finally:
            # Always reap the worker thread, even if streaming raises.
            gen_thread.join()

        final_text = "".join(out_text)

        return GENResult(text=final_text)

    def _tok_function(
        self,
        raw_dataset,
        args: GENTokArgs = GENTokArgs(),
        file_type: str = "jsonl",
        procs: Optional[int] = None,
    ) -> Dataset:
        """Tokenize a raw dataset for causal-LM training/eval."""
        max_in_len = get_max_in_len(args.max_len, self.tokenizer)

        return tokenize_gen(
            tokenizer=self.tokenizer,
            dataset=raw_dataset,
            max_len=max_in_len,
            bs=args.bs,
            procs=procs,
        )

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        resume_path: str = "",
    ):
        """Fine-tune the model; thin wrapper over the base-class loop."""
        return super(EricGeneration, self).train(
            train_path, args, eval_path, resume_path=resume_path
        )

    def eval(
        self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()
    ) -> EvalResult:
        """Evaluate the model; thin wrapper over the base-class loop."""
        return super(EricGeneration, self).eval(
            eval_path=eval_path, args=args
        )

    def tok(self, path: str, out_dir: str, args: GENTokArgs = GENTokArgs()):
        """Tokenize the file at *path* and save the result to *out_dir*."""
        return super(EricGeneration, self).tok(
            path=path, out_dir=out_dir, args=args
        )

    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        """Resolve (config, tokenizer, model) from the stored init args."""
        return get_model_components(
            model_name_path=self.eric_args.model_name,
            trust_remote_code=self.eric_args.trust_remote_code,
            model_class=self.eric_args.model_class,
            tokenizer_path=self.eric_args.tokenizer,
            precision=self.precision_type,
        )

    def _format_tokenized_example(self, example: dict) -> dict:
        """Project a tokenized example down to the fields the collator uses."""
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": example["labels"],
        }

    def _get_default_eval_models(self) -> List[EvalModel]:
        """Generation has no default eval metrics."""
        return []

    def _get_model_ready(self):
        """Move the model to the device, set eval mode, and resolve token ids.

        Returns:
            ``(pad_id, eos_id)`` for use as generation kwargs.

        Raises:
            EricInferenceError: if neither a pad nor an eos token id can be
                resolved from the tokenizer/model config.
        """
        self.model = self.model.to(self.device)
        self.model.eval()

        if self.tokenizer.pad_token_id is not None:
            pad_id = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            # Fall back to EOS as padding — the usual causal-LM convention.
            pad_id = self.tokenizer.eos_token_id
        else:
            raise EricInferenceError(
                "Tokenizer doesn't have a pad_token_id or eos_token_id token"
            )

        if self.model.config.eos_token_id is not None:
            eos_id = self.model.config.eos_token_id
        elif self.tokenizer.eos_token_id is not None:
            eos_id = self.tokenizer.eos_token_id
        else:
            # Fixed typo in the original message ("don't't").
            raise EricInferenceError(
                "The model and the tokenizer don't define an eos_token_id"
            )
        return pad_id, eos_id

    def _prep_model(self):
        """Attach a sampling-oriented default GenerationConfig to the model.

        Fix: the original assigned non-standard attribute names
        (``min_len``/``max_len``/``temp``), which ``generate()`` ignores;
        the canonical GenerationConfig field names are used instead so the
        defaults from ``GENCallArgs`` actually take effect.
        """
        generation_config = GenerationConfig.from_model_config(self.model.config)
        args = GENCallArgs()
        generation_config.num_beams = 1
        generation_config.early_stopping = False
        generation_config.do_sample = True
        generation_config.min_length = args.min_len
        generation_config.max_length = args.max_len
        generation_config.temperature = args.temp
        generation_config.top_p = args.top_p
        self.model.generation_config = generation_config

    def _get_readme(self, repo_id: str) -> str:
        """Render the model-card README pushed alongside *repo_id*."""
        readme_text = textwrap.dedent(f"""\
        ---
        tags:
        - erictransformer
        - eric-generation
        ---

        # {repo_id}

        ## Installation

        ```
        pip install erictransformer
        ```

        ## Usage

        ```python
        from erictransformer import EricGeneration, GENCallArgs

        eric_gen = EricGeneration(model_name="{repo_id}")

        result = eric_gen('Hello world')

        print(result.text)

        # Streaming is also possible (see docs)
        ```

        See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
        """)

        return readme_text
@@ -0,0 +1,231 @@
1
+ import textwrap
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ from datasets import Dataset
5
+ from transformers import (
6
+ AutoModelForSequenceClassification,
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ PretrainedConfig,
10
+ PreTrainedModel,
11
+ PreTrainedTokenizerBase,
12
+ TextClassificationPipeline,
13
+ )
14
+
15
+ from erictransformer.args import EricTrainArgs, EricEvalArgs
16
+ from erictransformer.eval_models import EvalModel, TCAccuracyEvalModel
17
+ from erictransformer.exceptions import EricInferenceError, EricTokenizationError
18
+ from erictransformer.eric_tasks.args import (
19
+ TCCallArgs,
20
+ TCTokArgs,
21
+ )
22
+ from erictransformer.eric_tasks.inference_engine.text_classification import (
23
+ tc_inference,
24
+ )
25
+ from erictransformer.eric_tasks.results import TCResult
26
+ from erictransformer.eric_tasks.tok.tok_functions import get_max_in_len
27
+ from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
28
+ from erictransformer.loops import EvalResult
29
+ from erictransformer.utils.init import get_model_components_tc
30
+ from erictransformer.validator import TCValidator
31
+
32
+
33
class EricTextClassification(EricTransformer):
    """Sequence-classification task built on :class:`EricTransformer`.

    Wraps ``AutoModelForSequenceClassification``: calling the instance runs
    single-text inference and returns per-label scores, while
    ``train``/``eval``/``tok`` delegate to the base class with
    classification-specific tokenization.
    """

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None] = "bert-base-uncased",
        *,
        trust_remote_code: bool = False,
        tokenizer: Union[str, AutoTokenizer] = None,
        labels: Optional[List[str]] = None,
    ):
        """Load (or wrap) a classification model.

        Args:
            model_name: Hub id/path, an already-loaded model, or ``None``.
            trust_remote_code: Forwarded to the ``transformers`` loaders.
            tokenizer: Hub id/path or tokenizer instance; ``None`` means the
                model's own tokenizer is resolved by the component loader.
            labels: Optional explicit label names used to size/name the
                classification head.
        """
        model_class = AutoModelForSequenceClassification

        self.labels = labels

        eric_args = EricTransformerArgs(
            model_name=model_name,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
        )

        super().__init__(eric_args)

        self._pipeline_class = TextClassificationPipeline

        self.task_validator = TCValidator(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=self.logger,
            labels=self.labels,
        )

        self._data_collator = DataCollatorWithPadding(self.tokenizer)

        # Maps class index -> human-readable label, taken from model config.
        self.id2label = self.config.id2label

    def __call__(self, text: str, args: TCCallArgs = TCCallArgs()) -> TCResult:
        """Classify *text* and return labels with their scores.

        Args:
            text: The input to classify.
            args: Per-call settings, validated by the task validator.

        Returns:
            A :class:`TCResult` with parallel ``labels`` and ``scores`` lists.

        Raises:
            EricInferenceError: if the underlying inference call fails.
        """
        self.task_validator.validate_call(text, args)

        self._get_model_ready()

        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            padding_side="left",
        ).to(self.device)

        try:
            results = tc_inference(
                tokens=tokens, model=self.model, id2label=self.id2label
            )[0]

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise EricInferenceError(
                f"Failed to call EricTextClassification's pipeline: {e}"
            ) from e

        labels = []
        scores = []
        for label_and_score in results:
            labels.append(label_and_score[0])
            scores.append(label_and_score[1])
        return TCResult(labels=labels, scores=scores)

    def _tok_function(
        self,
        raw_dataset,
        args: TCTokArgs = TCTokArgs(),
        file_type: str = "",
        procs: Optional[int] = None,
    ) -> Dataset:
        """Tokenize a raw ``text``/``label`` dataset into torch tensors."""
        max_in_len = get_max_in_len(args.max_len, self.tokenizer)

        def __preprocess_function(case):
            # Tokenize one batch and carry the label column through as
            # "labels", the name the collator/model expect.
            try:
                result = self.tokenizer(
                    case["text"],
                    truncation=True,
                    padding="max_length",
                    max_length=max_in_len,
                )
                result["labels"] = case["label"]
                return result
            except Exception as e:
                raise EricTokenizationError(
                    f"Tokenization failed during preprocessing: {e}"
                ) from e

        try:
            tok_dataset = raw_dataset.map(
                __preprocess_function,
                batched=True,
                remove_columns=["text", "label"],
                desc="Tokenizing...",
                batch_size=args.bs,
                num_proc=procs,
            )

            tok_dataset.set_format(
                type="torch", columns=["input_ids", "attention_mask", "labels"]
            )
            return tok_dataset
        except Exception as e:
            raise EricTokenizationError(
                f"Failed to apply preprocessing function over dataset: {e}"
            ) from e

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        resume_path: str = "",
    ):
        """Fine-tune the classifier; thin wrapper over the base-class loop."""
        return super(EricTextClassification, self).train(
            train_path, args, eval_path, resume_path=resume_path
        )

    def eval(
        self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()
    ) -> EvalResult:
        """Evaluate the classifier; thin wrapper over the base-class loop."""
        return super(EricTextClassification, self).eval(
            eval_path=eval_path, args=args
        )

    def tok(
        self,
        path: str,
        out_dir: str,
        args: TCTokArgs = TCTokArgs(),
        max_cases: Union[None, int] = None,
    ):
        """Tokenize the file at *path* and save the result to *out_dir*.

        NOTE(review): ``max_cases`` is accepted but not forwarded to
        ``super().tok`` and therefore has no effect — confirm whether the
        base class should receive it or the parameter should be removed.
        """
        return super(EricTextClassification, self).tok(
            path=path, out_dir=out_dir, args=args
        )

    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        """Resolve (config, tokenizer, model), honouring explicit labels."""
        return get_model_components_tc(
            model_name_path=self.eric_args.model_name,
            trust_remote_code=self.eric_args.trust_remote_code,
            model_class=self.eric_args.model_class,
            tokenizer_path=self.eric_args.tokenizer,
            labels=self.labels,
            precision=self.precision_type,
        )

    def _format_tokenized_example(self, example: dict) -> dict:
        """Project a tokenized example to collator fields; labels become int."""
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": int(example["labels"]),
        }

    def _get_default_eval_models(self) -> List[EvalModel]:
        """Accuracy is the default metric for classification."""
        return [TCAccuracyEvalModel()]

    def _get_model_ready(self):
        """Move the model to the target device and switch to eval mode."""
        self.model = self.model.to(self.device)
        self.model.eval()

    def _prep_model(self):
        """No generation config is needed for classification."""
        pass

    def _get_readme(self, repo_id: str) -> str:
        """Render the model-card README pushed alongside *repo_id*."""
        readme_text = textwrap.dedent(f"""\
        ---
        tags:
        - erictransformer
        - eric-text-classification
        ---
        # {repo_id}

        ## Installation

        ```
        pip install erictransformer
        ```

        ## Usage

        ```python
        from erictransformer import EricTextClassification

        eric_tc = EricTextClassification(model_name="{repo_id}")

        result = eric_tc('Hello world')

        print(result.labels[0])
        print(result.scores[0])
        ```

        See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
        """)

        return readme_text