erictransformer 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
import threading
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Iterator, List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
from datasets import Dataset
|
|
7
|
+
from transformers import (
|
|
8
|
+
AutoModelForSeq2SeqLM,
|
|
9
|
+
DataCollatorForSeq2Seq,
|
|
10
|
+
GenerationConfig,
|
|
11
|
+
PretrainedConfig,
|
|
12
|
+
PreTrainedModel,
|
|
13
|
+
PreTrainedTokenizer,
|
|
14
|
+
PreTrainedTokenizerBase,
|
|
15
|
+
PreTrainedTokenizerFast,
|
|
16
|
+
TextIteratorStreamer,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from erictransformer.args import EricTrainArgs, EricEvalArgs
|
|
20
|
+
from erictransformer.eval_models import EvalModel
|
|
21
|
+
from erictransformer.exceptions import EricTokenizationError
|
|
22
|
+
from erictransformer.eric_tasks.args import (
|
|
23
|
+
TTCallArgs,
|
|
24
|
+
TTTokArgs,
|
|
25
|
+
)
|
|
26
|
+
from erictransformer.eric_tasks.misc import generate_tt_kwargs, get_pad_eos
|
|
27
|
+
from erictransformer.eric_tasks.results import TTResult
|
|
28
|
+
from erictransformer.eric_tasks.tok.tok_functions import get_max_in_len
|
|
29
|
+
from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
|
|
30
|
+
from erictransformer.loops import EvalResult
|
|
31
|
+
from erictransformer.utils import get_model_components
|
|
32
|
+
from erictransformer.validator import TTValidator
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(kw_only=True)
class TTStreamResult:
    """A single chunk of generated output yielded by EricTextToText.stream()."""

    # Decoded text fragment produced by the streamer for this step.
    text: str
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class EricTextToText(EricTransformer):
    """Sequence-to-sequence ("text to text") task wrapper around EricTransformer.

    Loads an AutoModelForSeq2SeqLM (T5 by default) and exposes inference
    (``__call__`` / ``stream``), training, evaluation and tokenization
    entry points for text-to-text tasks.
    """

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None] = "google-t5/t5-base",
        *,
        trust_remote_code: bool = False,
        tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast, None] = None,
    ):
        model_class = AutoModelForSeq2SeqLM
        eric_args = EricTransformerArgs(
            model_name=model_name,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
        )
        super().__init__(eric_args)

        self.task_validator = TTValidator(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=self.logger,
        )

        # self.model may be None (e.g. tokenization-only workflows); only
        # prepare generation state when a model was actually loaded.
        if self.model is not None:
            self.pad_token_id, self.eos_token_id = get_pad_eos(
                self.tokenizer, self.model
            )

            self._prep_model()
            self._data_collator = DataCollatorForSeq2Seq(
                self.tokenizer, model=self.model
            )

    def _get_call_thread_streamer(self, text: str, args: TTCallArgs = TTCallArgs()):
        """Tokenize ``text`` and build a (thread, streamer) pair for generation.

        The returned thread runs ``model.generate``; the streamer yields the
        decoded text incrementally. The caller is responsible for starting
        and joining the thread.
        """
        tokens = self.tokenizer(text, return_tensors="pt", truncation=True).to(
            self.device
        )
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        # generate() expects a batch dimension.
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        gen_streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        gen_kwargs = generate_tt_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            streamer=gen_streamer,
            args=args,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        gen_thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)

        return gen_thread, gen_streamer

    def __call__(
        self,
        text: str,
        args: TTCallArgs = TTCallArgs(),
    ) -> TTResult:
        """Run generation to completion and return the full output text."""
        self._get_model_ready_inference()
        self.task_validator.validate_call(text, args)
        gen_thread, gen_streamer = self._get_call_thread_streamer(text, args)
        gen_thread.start()
        out_text = []
        try:
            # Bug fix: the loop variable previously shadowed the `text`
            # parameter; also dropped a stray `pass` after join().
            for chunk in gen_streamer:
                out_text.append(chunk)
        finally:
            # Always reap the generation thread, even if iteration raises.
            gen_thread.join()

        final_text = "".join(out_text)
        return TTResult(text=final_text)

    def stream(
        self, text: str, args: TTCallArgs = TTCallArgs()
    ) -> Iterator[TTStreamResult]:
        """Yield generated output incrementally as TTStreamResult chunks."""
        self._get_model_ready_inference()
        self.task_validator.validate_call(text, args)

        gen_thread, gen_streamer = self._get_call_thread_streamer(text, args)
        gen_thread.start()
        try:
            # Renamed loop variable so it no longer shadows the `text` parameter.
            for chunk in gen_streamer:
                yield TTStreamResult(text=chunk)
        finally:
            gen_thread.join()

    def _tok_function(
        self,
        raw_dataset,
        args: TTTokArgs = TTTokArgs(),
        file_type: str = "",
        procs: Optional[int] = None,
    ) -> Dataset:
        """Tokenize an input/target dataset into seq2seq training features.

        Raises:
            EricTokenizationError: when tokenization or dataset mapping fails.
        """
        final_max_in_len = get_max_in_len(
            args.max_in_len, self.tokenizer
        )
        # Reuses the same length-resolution helper for the output side.
        final_max_out_len = get_max_in_len(
            args.max_out_len, self.tokenizer
        )

        def __preprocess_function(examples):
            try:
                model_inputs = self.tokenizer(
                    examples["input"],
                    max_length=final_max_in_len,
                    truncation=True,
                    padding="max_length",
                )

                labels = self.tokenizer(
                    examples["target"],
                    max_length=final_max_out_len,
                    truncation=True,
                    padding="max_length",
                )

                # Mask pad positions to -100 so the loss ignores them.
                model_inputs["labels"] = [
                    [
                        (tok if tok != self.tokenizer.pad_token_id else -100)
                        for tok in seq
                    ]
                    for seq in labels["input_ids"]
                ]
                return model_inputs
            except Exception as e:
                raise EricTokenizationError(
                    f"Tokenization failed during preprocessing: {e}"
                )

        try:
            tok_dataset = raw_dataset.map(
                __preprocess_function,
                batched=True,
                remove_columns=["input", "target"],
                batch_size=args.bs,
                desc="Tokenizing...",
                num_proc=procs,
            )
            tok_dataset.set_format(
                type="torch", columns=["input_ids", "attention_mask", "labels"]
            )
            return tok_dataset
        except Exception as e:
            raise EricTokenizationError(
                f"Failed to apply preprocessing function over dataset: {e}"
            )

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        resume_path: str = "",
    ):
        """Fine-tune the model; thin wrapper over EricTransformer.train."""
        return super().train(train_path, args, eval_path, resume_path=resume_path)

    def eval(self, eval_path: str = "", args=EricEvalArgs()) -> EvalResult:
        """Evaluate the model; thin wrapper over EricTransformer.eval."""
        return super().eval(eval_path=eval_path, args=args)

    def tok(
        self,
        path: str,
        out_dir: str,
        args: TTTokArgs = TTTokArgs(),
    ):
        """Pre-tokenize a dataset to disk; thin wrapper over EricTransformer.tok."""
        return super().tok(path=path, out_dir=out_dir, args=args)

    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        """Load (config, tokenizer, model) per the constructor arguments."""
        return get_model_components(
            model_name_path=self.eric_args.model_name,
            trust_remote_code=self.eric_args.trust_remote_code,
            model_class=self.eric_args.model_class,
            tokenizer_path=self.eric_args.tokenizer,
            precision=self.precision_type,
        )

    def _format_tokenized_example(self, example: dict) -> dict:
        """Project a tokenized example onto the columns the collator expects."""
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": example["labels"],
        }

    def _get_default_eval_models(self) -> List[EvalModel]:
        """Text-to-text defines no default eval models."""
        return []

    def _prep_model(self):
        """Install a default GenerationConfig mirroring TTCallArgs defaults."""
        generation_config = GenerationConfig.from_model_config(self.model.config)
        args = TTCallArgs()
        generation_config.num_beams = 1
        generation_config.early_stopping = False
        generation_config.do_sample = True
        # Bug fix: GenerationConfig's fields are `min_length` / `temperature`;
        # the previous `min_len` / `temp` attributes were never read by
        # generate(), so these defaults silently had no effect.
        generation_config.min_length = args.min_len
        generation_config.max_length = args.max_len
        generation_config.temperature = args.temp
        generation_config.top_p = args.top_p
        self.model.generation_config = generation_config

    def _get_readme(self, repo_id: str) -> str:
        """Return markdown README content for a model pushed to `repo_id`."""
        readme_text = textwrap.dedent(f"""\
            ---
            tags:
            - erictransformer
            - eric-text-to-text
            ---
            # {repo_id}

            ## Installation

            ```
            pip install erictransformer
            ```

            ## Usage

            ```python
            from erictransformer import EricTextToText, TTCallArgs

            eric_tt = EricTextToText(model_name="{repo_id}")

            text = 'Hello world'

            result = eric_tt(text)
            print(result.text)

            # Stream
            for chunk in eric_tt.stream(text):
                print(chunk.text, end="")
            ```

            See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
            """)

        return readme_text
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import List, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@torch.inference_mode()
def tc_inference(
    tokens,
    model,
    id2label,
) -> List[List[Tuple[str, float]]]:
    """Run a classification forward pass and rank labels per example.

    Returns, for each row in the batch, (label, probability) pairs sorted
    by descending probability.
    """
    logits = model(
        input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"]
    ).logits
    probabilities = logits.softmax(dim=-1)

    sorted_scores, sorted_ids = torch.sort(probabilities, dim=-1, descending=True)

    ranked: List[List[Tuple[str, float]]] = []
    for id_row, score_row in zip(sorted_ids, sorted_scores):
        ranked.append(
            [
                (id2label[int(label_id)], float(score))
                for label_id, score in zip(id_row.tolist(), score_row.tolist())
            ]
        )

    return ranked
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from erictransformer.eric_tasks.misc.call_utils import (
|
|
2
|
+
format_messages,
|
|
3
|
+
generate_gen_kwargs,
|
|
4
|
+
generate_tt_kwargs,
|
|
5
|
+
)
|
|
6
|
+
from erictransformer.eric_tasks.misc.get_pad_eos import get_pad_eos
|
|
7
|
+
from erictransformer.eric_tasks.misc.rag import (
|
|
8
|
+
create_search_prompt_chat,
|
|
9
|
+
formate_rag_content,
|
|
10
|
+
formate_rag_message,
|
|
11
|
+
)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from transformers import TextIteratorStreamer
|
|
4
|
+
|
|
5
|
+
from erictransformer.exceptions import EricInferenceError
|
|
6
|
+
from erictransformer.eric_tasks.args import CHATCallArgs, GENCallArgs, TTCallArgs
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def format_messages(text):
    """Normalize user input into a chat-messages list.

    A plain string becomes a single user turn; a list of dicts is assumed
    to already be in messages format and is returned unchanged.

    Raises:
        EricInferenceError: if `text` is neither a string nor a list of dicts.
    """
    if isinstance(text, str):
        return [{"role": "user", "content": text}]
    if isinstance(text, list) and all(isinstance(el, dict) for el in text):
        return text
    raise EricInferenceError("Wrong input format")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def generate_gen_kwargs(
    input_ids,
    attention_mask,
    streamer: TextIteratorStreamer,
    args: Union[CHATCallArgs, GENCallArgs],
    eos_token_id: int,
    pad_token_id: int,
) -> dict:
    """Build the kwargs dict passed to model.generate() for chat/generation.

    Ensures the effective max length is never below the requested min, and
    maps the task args onto the parameter names generate() understands.
    """
    # Guarantee max >= min so generate() never receives an inverted range.
    max_len = args.max_len
    if args.min_len > args.max_len:
        max_len = args.min_len

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_len,
        # Bug fix: transformers' generate() expects `temperature`; an unknown
        # `temp` kwarg is rejected by its model-kwargs validation.
        temperature=args.temp,
        top_p=args.top_p,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )

    return gen_kwargs
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def generate_tt_kwargs(
    input_ids,
    attention_mask,
    streamer: TextIteratorStreamer,
    args: Union[CHATCallArgs, TTCallArgs],
    eos_token_id: int,
    pad_token_id: int,
) -> dict:
    """Build the kwargs dict passed to model.generate() for text-to-text.

    Ensures the effective max length is never below the requested min, and
    maps the task args onto the parameter names generate() understands.
    """
    # Guarantee max >= min so generate() never receives an inverted range.
    max_len = args.max_len
    if args.min_len > args.max_len:
        max_len = args.min_len

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_len,
        # Bug fix: transformers' generate() expects `temperature`; an unknown
        # `temp` kwarg is rejected by its model-kwargs validation.
        temperature=args.temp,
        top_p=args.top_p,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )

    return gen_kwargs
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
2
|
+
|
|
3
|
+
from erictransformer.exceptions import EricInferenceError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_pad_eos(tokenizer: PreTrainedTokenizer, model: PreTrainedModel):
    """Resolve (pad_token_id, eos_token_id) for generation.

    Pad falls back to the tokenizer's EOS when no pad token exists; EOS
    prefers the model config's value over the tokenizer's.

    Raises:
        EricInferenceError: when neither source defines the required id.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        raise EricInferenceError(
            "Tokenizer doesn't have a pad_token_id or eos_token_id token"
        )

    eos_id = model.config.eos_token_id
    if eos_id is None:
        eos_id = tokenizer.eos_token_id
    if eos_id is None:
        raise EricInferenceError(
            "The model and the tokenizer don't define an eos_token_id"
        )

    return pad_id, eos_id
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from ericsearch import RankerResult
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def formate_rag_content(text: str, data_result: RankerResult):
    """Render a retrieved search result into a RAG context string."""
    # This f-string is a runtime-visible prompt; its wording is intentional.
    prompt = f"Based on the search query: '{text}'. The following data may be relevant: ' {data_result.text} '"
    return prompt
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def formate_rag_message(rag_content: str):
    """Wrap RAG context text as a user-role chat message dict."""
    message = {"role": "user", "content": rag_content}
    return message
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def create_search_prompt_chat(text) -> str:
    """Extract the search query from chat input.

    A plain string is the query itself; for a messages list the query is
    the content of the most recent message. (Annotation widened: the old
    `text: str` contradicted the list-handling branch.)
    """
    # isinstance instead of `type(...) ==`: handles str subclasses, per PEP 8.
    if isinstance(text, str):
        return text
    return text[-1]["content"]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from erictransformer.args import CallResult
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(kw_only=True)
class GENResult(CallResult):
    """Result of a text-generation call."""

    # Full generated text.
    text: str
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(kw_only=True)
class CHATResult(CallResult):
    """Result of a (non-streaming) chat call."""

    # Full response text.
    text: str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(kw_only=True)
class StreamCHATResult(CallResult):
    """One streamed chunk of a chat response."""

    # Decoded text fragment for this chunk.
    text: str
    # Chunk phase/channel tag; values are producer-defined — presumably set
    # by the chat stream handlers (e.g. reasoning vs. answer). TODO confirm.
    mode: str
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(kw_only=True)
class TCResult(CallResult):
    """Result of a text-classification call."""

    # Class labels; presumably ordered by descending score — verify in caller.
    labels: List[str]
    # Scores aligned index-for-index with `labels`.
    scores: List[float]
|
|
27
|
+
|
|
28
|
+
@dataclass(kw_only=True)
class TTResult(CallResult):
    """Result of a text-to-text call."""

    # Full generated output text.
    text: str
|
|
File without changes
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from datasets import Dataset
|
|
4
|
+
from transformers import PreTrainedTokenizer
|
|
5
|
+
|
|
6
|
+
from erictransformer.exceptions import EricTokenizationError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_max_in_len(
    max_len: Union[int, None], tokenizer: PreTrainedTokenizer
) -> int:
    """Resolve an effective max sequence length.

    A positive `max_len` wins; otherwise fall back to the tokenizer's
    `model_max_length`, capped at 512 when the tokenizer reports a sentinel
    "unbounded" value.
    """
    # Bug fix: `max_len` is declared Optional, but `None >= 1` raised
    # TypeError. Treat None the same as "unset".
    if max_len is not None and max_len >= 1:
        return max_len
    # Some tokenizers report a huge sentinel (e.g. ~1e30) when no real
    # limit is configured; use a sane default instead.
    if tokenizer.model_max_length > 10_000_000:
        return 512
    return tokenizer.model_max_length
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def tokenize_gen(
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    max_len: int,
    bs: int,
    procs: int = 1,
) -> Dataset:
    """Tokenize a plain-text dataset for language-model training.

    Pads/truncates each example to `max_len`, builds labels from the input
    ids with pad positions masked to -100, and returns the dataset
    formatted as torch tensors.

    Raises:
        EricTokenizationError: when tokenization or dataset mapping fails.
    """

    def _encode_batch(batch):
        try:
            tokens = tokenizer(
                batch["text"],
                padding="max_length",
                truncation=True,
                max_length=max_len,
            )
            pad_id = tokenizer.pad_token_id
            # Pad positions become -100 so the loss ignores them.
            tokens["labels"] = [
                [-100 if token == pad_id else token for token in input_ids]
                for input_ids in tokens["input_ids"]
            ]
            return tokens
        except Exception as e:
            raise EricTokenizationError(
                f"Tokenization failed during batch processing. Error: {e}"
            )

    try:
        tokenized = dataset.map(
            _encode_batch,
            batched=True,
            remove_columns=["text"],
            batch_size=bs,
            desc="Tokenizing...",
            num_proc=procs,
        )
        tokenized.set_format(
            "torch", columns=["input_ids", "attention_mask", "labels"]
        )
        return tokenized
    except Exception as e:
        raise EricTokenizationError(
            f"Failed during dataset mapping or formatting. Error: {e}"
        )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def tokenize_chat_template(
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    max_len: int,
    bs: int,
    procs: int = 1,
) -> Dataset:
    """Tokenize chat-format examples via the tokenizer's chat template.

    Applies the template without a generation prompt, pads/truncates to
    `max_len`, masks pad positions to -100 in the labels, and returns the
    dataset formatted as torch tensors.

    Raises:
        EricTokenizationError: when template application or mapping fails.
    """

    def _apply_template(example):
        try:
            conversations = list(example["messages"])
            tokens = tokenizer.apply_chat_template(
                conversations,
                tokenize=True,
                add_generation_prompt=False,
                max_length=max_len,
                padding="max_length",
                return_dict=True,
                truncation=True,
            )

            pad_id = tokenizer.pad_token_id
            # Pad positions become -100 so the loss ignores them.
            tokens["labels"] = [
                [-100 if token == pad_id else token for token in input_ids]
                for input_ids in tokens["input_ids"]
            ]

            return tokens
        except Exception as e:
            raise EricTokenizationError(
                f"Tokenization failed during chat template application. Error: {e}"
            )

    try:
        tokenized = dataset.map(
            _apply_template,
            batched=True,
            remove_columns=["messages"],
            batch_size=bs,
            desc="Tokenizing...",
            num_proc=procs,
        )
        tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
        return tokenized
    except Exception as e:
        raise EricTokenizationError(
            f"Failed during dataset mapping or formatting. Error: {e}"
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from erictransformer.eric_tracker.eric_tracker import EricTracker
|