erictransformer 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
import threading
|
|
3
|
+
from typing import List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from datasets import Dataset
|
|
7
|
+
from transformers import (
|
|
8
|
+
AutoModelForCausalLM,
|
|
9
|
+
GenerationConfig,
|
|
10
|
+
PretrainedConfig,
|
|
11
|
+
PreTrainedModel,
|
|
12
|
+
PreTrainedTokenizerBase,
|
|
13
|
+
TextIteratorStreamer,
|
|
14
|
+
default_data_collator,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from erictransformer.args import EricTrainArgs, EricEvalArgs
|
|
18
|
+
from erictransformer.eval_models import EvalModel
|
|
19
|
+
from erictransformer.exceptions import EricInferenceError
|
|
20
|
+
from erictransformer.eric_tasks.args import (
|
|
21
|
+
GENCallArgs,
|
|
22
|
+
GENTokArgs,
|
|
23
|
+
)
|
|
24
|
+
from erictransformer.eric_tasks.misc import generate_gen_kwargs, get_pad_eos
|
|
25
|
+
from erictransformer.eric_tasks.results import GENResult
|
|
26
|
+
from erictransformer.eric_tasks.tok.tok_functions import (
|
|
27
|
+
get_max_in_len,
|
|
28
|
+
tokenize_gen,
|
|
29
|
+
)
|
|
30
|
+
from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
|
|
31
|
+
from erictransformer.loops import EvalResult
|
|
32
|
+
from erictransformer.utils import get_model_components
|
|
33
|
+
from erictransformer.validator import GENValidator
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class EricGeneration(EricTransformer):
    """Text-generation task built on a causal language model.

    Wraps an ``AutoModelForCausalLM`` checkpoint and exposes:
      * ``__call__`` -- generate a continuation for a prompt. Generation runs
        ``model.generate`` on a background thread feeding a
        ``TextIteratorStreamer``; the chunks are collected into one
        :class:`GENResult`.
      * ``train`` / ``eval`` / ``tok`` -- thin pass-throughs to the shared
        :class:`EricTransformer` workflow.
    """

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None] = "cerebras/Cerebras-GPT-111M",
        *,
        trust_remote_code: bool = False,
        tokenizer: Union[str, PreTrainedTokenizerBase, None] = None,
    ):
        """Initialize the generation task.

        Args:
            model_name: Checkpoint name/path, an already-loaded model, or
                ``None`` (loading is then deferred to the base class).
            trust_remote_code: Forwarded to the ``transformers`` loaders.
            tokenizer: Tokenizer name/path or instance; when ``None`` the
                tokenizer associated with ``model_name`` is used.
        """
        model_class = AutoModelForCausalLM

        eric_args = EricTransformerArgs(
            model_name=model_name,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
        )

        super().__init__(eric_args)
        self.task_validator = GENValidator(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=self.logger,
        )

        self._data_collator = default_data_collator

        if self.model is not None:
            # Resolve pad/eos token ids up front so the generation kwargs
            # built later are complete.
            self.pad_token_id, self.eos_token_id = get_pad_eos(
                self.tokenizer, self.model
            )
            self._prep_model()

    def _get_call_thread_streamer(self, text: str, args: GENCallArgs = GENCallArgs()):
        """Build (but do not start) a generation worker thread and its streamer.

        Returns:
            Tuple of a ``threading.Thread`` targeting ``self.model.generate``
            and the ``TextIteratorStreamer`` it writes decoded chunks to.
        """
        input_ids = self.tokenizer.encode(text, return_tensors="pt")
        # Some tokenizers return a 1-D tensor; generate() expects a batch dim.
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        input_ids = input_ids.to(self.model.device)

        # Single un-padded prompt, so the attention mask is all ones.
        attention_mask = torch.ones_like(
            input_ids, dtype=torch.long, device=self.model.device
        )

        gen_streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=False
        )

        gen_kwargs = generate_gen_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            streamer=gen_streamer,
            args=args,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        gen_thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)
        return gen_thread, gen_streamer

    def __call__(
        self, text: str, args: GENCallArgs = GENCallArgs()
    ) -> GENResult:
        """Generate a continuation of ``text``.

        Args:
            text: The prompt to continue.
            args: Sampling/length parameters for this call.

        Returns:
            A :class:`GENResult` holding the generated text.
        """
        self._get_model_ready_inference()
        gen_thread, gen_streamer = self._get_call_thread_streamer(text, args)
        gen_thread.start()
        out_text = []
        try:
            for stream_result in gen_streamer:
                if stream_result:
                    out_text.append(stream_result)
        finally:
            # Always reap the worker thread, even if iteration raises.
            gen_thread.join()

        final_text = "".join(out_text)

        return GENResult(text=final_text)

    def _tok_function(
        self,
        raw_dataset,
        args: GENTokArgs = GENTokArgs(),
        file_type: str = "jsonl",
        procs: Optional[int] = None,
    ) -> Dataset:
        """Tokenize a raw dataset for causal-LM training.

        ``max_len`` from ``args`` is clamped against the tokenizer's limit via
        ``get_max_in_len`` before delegating to ``tokenize_gen``.
        """
        max_in_len = get_max_in_len(args.max_len, self.tokenizer)

        return tokenize_gen(
            tokenizer=self.tokenizer,
            dataset=raw_dataset,
            max_len=max_in_len,
            bs=args.bs,
            procs=procs,
        )

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        resume_path: str = "",
    ):
        """Train the model; delegates to :meth:`EricTransformer.train`."""
        return super(EricGeneration, self).train(
            train_path, args, eval_path, resume_path=resume_path
        )

    def eval(
        self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()
    ) -> EvalResult:
        """Evaluate the model; delegates to :meth:`EricTransformer.eval`."""
        return super(EricGeneration, self).eval(
            eval_path=eval_path, args=args
        )

    def tok(self, path: str, out_dir: str, args: GENTokArgs = GENTokArgs()):
        """Tokenize a data file to ``out_dir``; delegates to the base class."""
        return super(EricGeneration, self).tok(
            path=path, out_dir=out_dir, args=args
        )

    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        """Load (config, tokenizer, model) for this task's checkpoint."""
        return get_model_components(
            model_name_path=self.eric_args.model_name,
            trust_remote_code=self.eric_args.trust_remote_code,
            model_class=self.eric_args.model_class,
            tokenizer_path=self.eric_args.tokenizer,
            precision=self.precision_type,
        )

    def _format_tokenized_example(self, example: dict) -> dict:
        """Project a tokenized example down to the fields the collator needs."""
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": example["labels"],
        }

    def _get_default_eval_models(self) -> List[EvalModel]:
        """Generation has no default eval metrics; loss alone is reported."""
        return []

    def _get_model_ready(self):
        """Move the model to the device, switch to eval mode, and resolve ids.

        Returns:
            Tuple ``(pad_id, eos_id)``.

        Raises:
            EricInferenceError: If neither a pad nor an eos token id can be
                resolved from the tokenizer/model config.
        """
        self.model = self.model.to(self.device)
        self.model.eval()

        # Prefer the tokenizer's pad token; fall back to its eos token.
        if self.tokenizer.pad_token_id is not None:
            pad_id = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            pad_id = self.tokenizer.eos_token_id
        else:
            raise EricInferenceError(
                "Tokenizer doesn't have a pad_token_id or eos_token_id token"
            )

        # Prefer the model config's eos id; fall back to the tokenizer's.
        if self.model.config.eos_token_id is not None:
            eos_id = self.model.config.eos_token_id
        elif self.tokenizer.eos_token_id is not None:
            eos_id = self.tokenizer.eos_token_id
        else:
            raise EricInferenceError(
                "The model and the tokenizer don't define an eos_token_id"
            )
        return pad_id, eos_id

    def _prep_model(self):
        """Attach a default sampling GenerationConfig derived from GENCallArgs.

        NOTE(review): ``min_len``/``max_len``/``temp`` are not standard
        ``GenerationConfig`` field names (those are ``min_length``/
        ``max_length``/``temperature``); setting them here only stores extra
        attributes. Presumably ``generate_gen_kwargs`` supplies the effective
        values at call time -- confirm before renaming.
        """
        generation_config = GenerationConfig.from_model_config(self.model.config)
        args = GENCallArgs()
        generation_config.num_beams = 1
        generation_config.early_stopping = False
        generation_config.do_sample = True
        generation_config.min_len = args.min_len
        generation_config.max_len = args.max_len
        generation_config.temp = args.temp
        generation_config.top_p = args.top_p
        self.model.generation_config = generation_config

    def _get_readme(self, repo_id: str) -> str:
        """Return the README markdown pushed alongside an uploaded model."""
        readme_text = textwrap.dedent(f"""\
        ---
        tags:
        - erictransformer
        - eric-generation
        ---

        # {repo_id}

        ## Installation

        ```
        pip install erictransformer
        ```

        ## Usage

        ```python
        from erictransformer import EricGeneration, GENCallArgs

        eric_gen = EricGeneration(model_name="{repo_id}")

        result = eric_gen('Hello world')

        print(result.text)

        # Streaming is also possible (see docs)
        ```

        See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
        """)

        return readme_text
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
from typing import List, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
from datasets import Dataset
|
|
5
|
+
from transformers import (
|
|
6
|
+
AutoModelForSequenceClassification,
|
|
7
|
+
AutoTokenizer,
|
|
8
|
+
DataCollatorWithPadding,
|
|
9
|
+
PretrainedConfig,
|
|
10
|
+
PreTrainedModel,
|
|
11
|
+
PreTrainedTokenizerBase,
|
|
12
|
+
TextClassificationPipeline,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from erictransformer.args import EricTrainArgs, EricEvalArgs
|
|
16
|
+
from erictransformer.eval_models import EvalModel, TCAccuracyEvalModel
|
|
17
|
+
from erictransformer.exceptions import EricInferenceError, EricTokenizationError
|
|
18
|
+
from erictransformer.eric_tasks.args import (
|
|
19
|
+
TCCallArgs,
|
|
20
|
+
TCTokArgs,
|
|
21
|
+
)
|
|
22
|
+
from erictransformer.eric_tasks.inference_engine.text_classification import (
|
|
23
|
+
tc_inference,
|
|
24
|
+
)
|
|
25
|
+
from erictransformer.eric_tasks.results import TCResult
|
|
26
|
+
from erictransformer.eric_tasks.tok.tok_functions import get_max_in_len
|
|
27
|
+
from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
|
|
28
|
+
from erictransformer.loops import EvalResult
|
|
29
|
+
from erictransformer.utils.init import get_model_components_tc
|
|
30
|
+
from erictransformer.validator import TCValidator
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class EricTextClassification(EricTransformer):
    """Sequence-classification task built on :class:`EricTransformer`.

    Wraps an ``AutoModelForSequenceClassification`` checkpoint and exposes
    ``__call__`` for single-text inference (labels + scores) plus the shared
    train/eval/tok workflow.
    """

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None] = "bert-base-uncased",
        *,
        trust_remote_code: bool = False,
        tokenizer: Union[str, AutoTokenizer, None] = None,
        labels: Optional[List[str]] = None,
    ):
        """Initialize the text-classification task.

        Args:
            model_name: Checkpoint name/path, an already-loaded model, or
                ``None`` (loading is then deferred to the base class).
            trust_remote_code: Forwarded to the ``transformers`` loaders.
            tokenizer: Tokenizer name/path or instance; when ``None`` the
                tokenizer associated with ``model_name`` is used.
            labels: Optional explicit class-label names used when building
                the model components.
        """
        model_class = AutoModelForSequenceClassification

        self.labels = labels

        eric_args = EricTransformerArgs(
            model_name=model_name,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
        )

        super().__init__(eric_args)

        self._pipeline_class = TextClassificationPipeline

        self.task_validator = TCValidator(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=self.logger,
            labels=self.labels,
        )

        self._data_collator = DataCollatorWithPadding(self.tokenizer)

        # Mapping from class index to human-readable label, from the config.
        self.id2label = self.config.id2label

    def __call__(self, text: str, args: TCCallArgs = TCCallArgs()) -> TCResult:
        """Classify ``text``.

        Args:
            text: The input to classify.
            args: Call-time options (validated by the task validator).

        Returns:
            A :class:`TCResult` with parallel ``labels`` and ``scores`` lists.

        Raises:
            EricInferenceError: If inference fails.
        """
        self.task_validator.validate_call(text, args)

        self._get_model_ready()

        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            padding_side="left",
        ).to(self.device)

        try:
            # Single input, so take the first (only) row of results.
            results = tc_inference(
                tokens=tokens, model=self.model, id2label=self.id2label
            )[0]

        except Exception as e:
            # Chain the cause so the underlying failure stays visible.
            raise EricInferenceError(
                f"Failed to call EricTextClassification's pipeline: {e}"
            ) from e

        labels = []
        scores = []
        for label_and_score in results:
            labels.append(label_and_score[0])
            scores.append(label_and_score[1])
        return TCResult(labels=labels, scores=scores)

    def _tok_function(
        self,
        raw_dataset,
        args: TCTokArgs = TCTokArgs(),
        file_type: str = "",
        procs: Optional[int] = None,
    ) -> Dataset:
        """Tokenize a raw ``{text, label}`` dataset into torch-format tensors.

        Raises:
            EricTokenizationError: If tokenization or dataset mapping fails.
        """
        max_in_len = get_max_in_len(args.max_len, self.tokenizer)

        def __preprocess_function(case):
            try:
                result = self.tokenizer(
                    case["text"],
                    truncation=True,
                    padding="max_length",
                    max_length=max_in_len,
                )
                result["labels"] = case["label"]
                return result
            except Exception as e:
                raise EricTokenizationError(
                    f"Tokenization failed during preprocessing: {e}"
                ) from e

        try:
            tok_dataset = raw_dataset.map(
                __preprocess_function,
                batched=True,
                remove_columns=["text", "label"],
                desc="Tokenizing...",
                batch_size=args.bs,
                num_proc=procs,
            )

            tok_dataset.set_format(
                type="torch", columns=["input_ids", "attention_mask", "labels"]
            )
            return tok_dataset
        except Exception as e:
            raise EricTokenizationError(
                f"Failed to apply preprocessing function over dataset: {e}"
            ) from e

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        resume_path: str = "",
    ):
        """Train the model; delegates to :meth:`EricTransformer.train`."""
        return super(EricTextClassification, self).train(
            train_path, args, eval_path, resume_path=resume_path
        )

    def eval(
        self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()
    ) -> EvalResult:
        """Evaluate the model; delegates to :meth:`EricTransformer.eval`."""
        return super(EricTextClassification, self).eval(
            eval_path=eval_path, args=args
        )

    def tok(
        self,
        path: str,
        out_dir: str,
        args: TCTokArgs = TCTokArgs(),
        max_cases: Union[None, int] = None,
    ):
        """Tokenize a data file to ``out_dir``; delegates to the base class.

        NOTE(review): ``max_cases`` is accepted but never forwarded to the
        base-class ``tok`` -- it currently has no effect. Confirm intent.
        """
        return super(EricTextClassification, self).tok(
            path=path, out_dir=out_dir, args=args
        )

    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        """Load (config, tokenizer, model), passing the optional label set."""
        return get_model_components_tc(
            model_name_path=self.eric_args.model_name,
            trust_remote_code=self.eric_args.trust_remote_code,
            model_class=self.eric_args.model_class,
            tokenizer_path=self.eric_args.tokenizer,
            labels=self.labels,
            precision=self.precision_type,
        )

    def _format_tokenized_example(self, example: dict) -> dict:
        """Project a tokenized example to collator fields; labels cast to int."""
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": int(example["labels"]),
        }

    def _get_default_eval_models(self) -> List[EvalModel]:
        """Classification reports accuracy by default."""
        return [TCAccuracyEvalModel()]

    def _get_model_ready(self):
        """Move the model to the configured device and switch to eval mode."""
        self.model = self.model.to(self.device)
        self.model.eval()

    def _prep_model(self):
        """No extra model preparation is needed for classification."""
        pass

    def _get_readme(self, repo_id: str) -> str:
        """Return the README markdown pushed alongside an uploaded model."""
        readme_text = textwrap.dedent(f"""\
        ---
        tags:
        - erictransformer
        - eric-text-classification
        ---
        # {repo_id}

        ## Installation

        ```
        pip install erictransformer
        ```

        ## Usage

        ```python
        from erictransformer import EricTextClassification

        eric_tc = EricTextClassification(model_name="{repo_id}")

        result = eric_tc('Hello world')

        print(result.labels[0])
        print(result.scores[0])
        ```

        See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
        """)

        return readme_text
|