erictransformer-0.0.1-py3-none-any.whl
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/eric_tasks/eric_chat.py
@@ -0,0 +1,369 @@
import textwrap
import threading
import traceback
from typing import Iterator, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    GenerationConfig,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TextIteratorStreamer,
    default_data_collator,
)
from ericsearch import EricSearch


from erictransformer.eval_models import EvalModel
from erictransformer.args import EricTrainArgs, EricEvalArgs
from erictransformer.eric_tasks.args import (
    CHATCallArgs,
    CHATTokArgs,
)
from erictransformer.eric_tasks.chat_stream_handlers import (
    CHATStreamResult,
    DefaultStreamHandler,
    GPTOSSSMHandler,
    SmolStreamHandler,
)
from erictransformer.eric_tasks.chat_templates import map_chat_roles
from erictransformer.eric_tasks.misc import (
    create_search_prompt_chat,
    format_messages,
    formate_rag_content,
    formate_rag_message,
    generate_gen_kwargs,
    get_pad_eos,
)
from erictransformer.eric_tasks.results import CHATResult
from erictransformer.eric_tasks.tok.tok_functions import (
    get_max_in_len,
    tokenize_chat_template,
)
from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
from erictransformer.loops import EvalResult
from erictransformer.utils import get_model_components
from erictransformer.validator import CHATValidator


class EricChat(EricTransformer):
    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None] = "openai/gpt-oss-20b",
        *,
        trust_remote_code: bool = False,
        tokenizer: Union[str, PreTrainedTokenizerBase, None] = None,
        eric_search: Optional[EricSearch] = None,
    ):
        model_class = AutoModelForCausalLM

        eric_args = EricTransformerArgs(
            model_name=model_name,
            model_class=model_class,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
        )

        super().__init__(eric_args)

        self.task_validator = CHATValidator(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=self.logger,
        )

        if not getattr(self.tokenizer, "chat_template", None):
            raise ValueError("The tokenizer must include a chat template")

        self._data_collator = default_data_collator

        self.logger.info("Using tokenizer's built-in chat template.")
        if model_name:
            self.config = self.model.config

        if self.model is not None:
            self.pad_token_id, self.eos_token_id = get_pad_eos(
                self.tokenizer, self.model
            )

        if (
            model_name is not None
        ):  # we don't need to initialize these if a model is not provided.
            if self.config.model_type == "smollm3":
                self.text_streamer_handler = SmolStreamHandler(self.tokenizer, self.logger)
                self.model_type = "smollm3"
            elif self.config.model_type == "gpt_oss":
                self.text_streamer_handler = GPTOSSSMHandler(self.tokenizer, self.logger)
                self.model_type = "gpt_oss"
            else:
                self.text_streamer_handler = DefaultStreamHandler(self.tokenizer)
                self.model_type = "default"

        if eric_search:
            self.eric_search = eric_search
        else:
            self.eric_search = None

        if self.model:
            self._prep_model()

        self.to_stream_tokens = []

    def _get_call_thread_streamer(
        self, messages: list[dict], args: CHATCallArgs = CHATCallArgs()
    ):
        mapped_messages = map_chat_roles(
            messages=messages,
            model_name=self.eric_args.model_name,
            model_type=self.model_type,
        )
        if self.model_type != "gpt_oss":
            input_ids = self.tokenizer.apply_chat_template(
                mapped_messages, add_generation_prompt=True, return_tensors="pt"
            )
        else:
            prompt = self.tokenizer.apply_chat_template(
                mapped_messages, add_generation_prompt=True, tokenize=False
            )

            # Force the model into its analysis channel, and pre-seed the
            # stream handler with the tokens we appended by hand.
            prompt += "<|channel|>analysis<|message|>"
            self.to_stream_tokens.append(self.text_streamer_handler.step("<|channel|>"))
            self.to_stream_tokens.append(self.text_streamer_handler.step("analysis"))
            self.to_stream_tokens.append(self.text_streamer_handler.step("<|message|>"))
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        input_ids = input_ids.to(self.model.device)

        attention_mask = torch.ones_like(
            input_ids, dtype=torch.long, device=self.model.device
        )

        gen_streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=False
        )

        gen_kwargs = generate_gen_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            streamer=gen_streamer,
            args=args,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        def thread_fn():
            try:
                self.model.generate(**gen_kwargs)
            except BaseException:
                # Catching BaseException here is OK because we are in a
                # background thread and therefore don't need to worry about
                # muting exceptions such as KeyboardInterrupt.
                err_str = traceback.format_exc()
                self.logger.error(f"Error in generate thread: {err_str}")
                gen_streamer.end()

        gen_thread = threading.Thread(target=thread_fn)
        return gen_thread, gen_streamer

    def __call__(
        self, text: Union[str, List[dict]], args: CHATCallArgs = CHATCallArgs()
    ) -> CHATResult:
        messages = format_messages(text)

        if self.eric_search is not None:
            search_query = create_search_prompt_chat(text)
            data_result = self.eric_search(search_query, args=args.search_args)

            if data_result:
                top_result = data_result[0]
                rag_content = formate_rag_content(text=text, data_result=top_result)
                rag_message = formate_rag_message(rag_content=rag_content)
                messages.insert(-1, rag_message)

        self._get_model_ready_inference()
        gen_thread, gen_streamer = self._get_call_thread_streamer(messages, args)
        gen_thread.start()
        out_text = []
        try:
            for text in gen_streamer:
                tokens = self.tokenizer.encode(text)
                for token in tokens:
                    token_string = self.tokenizer.decode(token)
                    stream_result = self.text_streamer_handler.step(token_string)
                    if stream_result:
                        if stream_result.marker == "text":
                            out_text.append(stream_result.text)
        finally:
            gen_thread.join()

        final_text = "".join(out_text)

        return CHATResult(text=final_text)

    def stream(
        self, text: Union[str, List[dict]], args: CHATCallArgs = CHATCallArgs()
    ) -> Iterator[CHATStreamResult]:
        messages = format_messages(text)

        if self.eric_search is not None:
            search_query = create_search_prompt_chat(text)
            yield CHATStreamResult(
                text="", marker="search", payload={"query": search_query}
            )

            data_result = self.eric_search(search_query, args=args.search_args)

            if data_result:
                top_result = data_result[0]
                yield CHATStreamResult(
                    text="",
                    marker="search_result",
                    payload={
                        "text": top_result.text,
                        "best_sentence": top_result.best_sentence,
                        "metadata": top_result.metadata,
                    },
                )

                rag_content = formate_rag_content(text=text, data_result=top_result)
                rag_message = formate_rag_message(rag_content=rag_content)
                messages.insert(-1, rag_message)

        self._get_model_ready_inference()

        gen_thread, gen_streamer = self._get_call_thread_streamer(messages, args)

        while self.to_stream_tokens:
            stream_result = self.to_stream_tokens.pop(0)
            if stream_result:
                yield stream_result

        gen_thread.start()
        try:
            for text in gen_streamer:
                tokens = self.tokenizer.encode(text)
                for token in tokens:
                    token_string = self.tokenizer.decode(token)
                    stream_result = self.text_streamer_handler.step(token_string)
                    if stream_result:
                        yield stream_result
        finally:
            gen_thread.join()

    def _tok_function(
        self,
        raw_dataset,
        args: CHATTokArgs = CHATTokArgs(),
        file_type: str = "jsonl",
        procs: Optional[int] = None,
    ) -> Dataset:
        max_in_len = get_max_in_len(args.max_len, self.tokenizer)

        return tokenize_chat_template(
            tokenizer=self.tokenizer,
            dataset=raw_dataset,
            max_len=max_in_len,
            bs=args.bs,
            procs=procs,
        )

    def train(
        self,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        *,
        resume_path: str = "",
    ):
        return super(EricChat, self).train(
            train_path, args, eval_path, resume_path=resume_path
        )

    def eval(
        self, eval_path: str = "", args: EricEvalArgs = EricEvalArgs()
    ) -> EvalResult:
        return super(EricChat, self).eval(eval_path=eval_path, args=args)

    def tok(self, path: str, out_dir: str, args: CHATTokArgs = CHATTokArgs()):
        return super(EricChat, self).tok(path=path, out_dir=out_dir, args=args)

    def _load_model_components(
        self,
    ) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
        return get_model_components(
            model_name_path=self.eric_args.model_name,
            trust_remote_code=self.eric_args.trust_remote_code,
            model_class=self.eric_args.model_class,
            tokenizer_path=self.eric_args.tokenizer,
            precision=self.precision_type,
        )

    def _format_tokenized_example(self, example: dict) -> dict:
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": example["labels"],
        }

    def _get_default_eval_models(self) -> List[EvalModel]:
        return []

    def _prep_model(self):
        args = CHATCallArgs()
        generation_config = GenerationConfig.from_model_config(self.model.config)
        generation_config.num_beams = 1
        generation_config.early_stopping = False
        generation_config.do_sample = True
        generation_config.min_len = args.min_len
        generation_config.max_len = args.max_len
        generation_config.temp = args.temp
        generation_config.top_p = args.top_p
        self.model.generation_config = generation_config

    def _get_readme(self, repo_id: str) -> str:
        readme_text = textwrap.dedent(f"""\
            ---
            tags:
            - erictransformer
            - eric-chat
            ---
            # {repo_id}

            ## Installation

            ```
            pip install erictransformer
            ```

            ## Usage

            ```python
            from erictransformer import EricChat, CHATCallArgs

            eric_chat = EricChat(model_name="{repo_id}")

            text = 'Hello world'

            result = eric_chat(text)
            print(result.text)

            # Streaming is also possible (see docs)
            ```

            See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
            """)

        return readme_text
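Taken together, `__call__` drains the streamer internally and returns a single `CHATResult`, while `stream` surfaces every `CHATStreamResult`, including the RAG `search`/`search_result` events and, for gpt_oss models, the pre-seeded analysis-channel tokens. A minimal consumer sketch, assuming only the public re-exports shown in the embedded README template; the model id is a placeholder:

```python
from erictransformer import EricChat, CHATCallArgs

# Placeholder model id; any causal LM whose tokenizer ships a chat
# template should work (the constructor raises ValueError otherwise).
eric_chat = EricChat(model_name="HuggingFaceTB/SmolLM3-3B")

args = CHATCallArgs()  # generation knobs such as max_len, temp, top_p

for chunk in eric_chat.stream("Hello world", args=args):
    if chunk.marker == "search":
        print(f"[searching: {chunk.payload['query']}]")
    elif chunk.marker == "text":
        print(chunk.text, end="", flush=True)
```
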
erictransformer/eric_tasks/eric_chat_mlx.py
@@ -0,0 +1,278 @@
import os
import tempfile
import textwrap
from typing import Iterator, List, Optional, Union

from huggingface_hub import HfApi
from transformers import AutoConfig
from ericsearch import EricSearch


try:
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from mlx_lm.utils import save_model
except ImportError as err:
    raise ImportError("""
        Failed to import MLX. If you are using a Mac,
        try `pip install mlx-lm`. If you have a CUDA-compatible device,
        try `pip install mlx[cuda] mlx-lm`. Otherwise, try `pip install mlx[cpu] mlx-lm`.
        """) from err

from erictransformer.exceptions import (
    EricNoModelError,
    EricPushError,
    EricSaveError,
)
from erictransformer.eric_tasks.args import CHATCallArgs
from erictransformer.eric_tasks.chat_stream_handlers import (
    CHATStreamResult,
    DefaultStreamHandler,
    GPTOSSSMHandler,
    SmolStreamHandler,
)
from erictransformer.eric_tasks.chat_templates import map_chat_roles
from erictransformer.eric_tasks.misc import (
    create_search_prompt_chat,
    format_messages,
    formate_rag_content,
    formate_rag_message,
)
from erictransformer.eric_tasks.results import CHATResult
from erictransformer.utils import et_get_logger


class EricChatMLX:
    def __init__(
        self,
        model_name: str = "mlx-community/SmolLM3-3B-4bit",
        *,
        eric_search: Optional[EricSearch] = None,
    ):
        self.model_name = model_name
        self.model, self.tokenizer = load(model_name)
        self.logger = et_get_logger()

        if not getattr(self.tokenizer, "chat_template", None):
            raise ValueError("The tokenizer must include a chat template")

        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=False)

        if self.config.model_type == "smollm3":
            self.text_streamer_handler = SmolStreamHandler(self.tokenizer, self.logger)
            self.model_type = "smollm3"
        elif self.config.model_type == "gpt_oss":
            self.text_streamer_handler = GPTOSSSMHandler(self.tokenizer, self.logger)
            self.model_type = "gpt_oss"
        else:
            self.text_streamer_handler = DefaultStreamHandler(self.tokenizer)
            self.model_type = "default"

        if eric_search:
            self.eric_search = eric_search
        else:
            self.eric_search = None

        self.to_stream_tokens = []

    def _get_streamer_prompt(
        self, messages: List[dict], args: CHATCallArgs = CHATCallArgs()
    ):
        mapped_messages = map_chat_roles(
            messages=messages, model_name=self.model_name, model_type=self.model_type
        )

        prompt = self.tokenizer.apply_chat_template(
            mapped_messages, add_generation_prompt=True, tokenize=False
        )
        # Always think: force gpt_oss models into their analysis channel and
        # pre-seed the stream handler with the tokens we appended by hand.
        if self.model_type == "gpt_oss":
            prompt += "<|channel|>analysis<|message|>"
            self.to_stream_tokens.append(self.text_streamer_handler.step("<|channel|>"))
            self.to_stream_tokens.append(self.text_streamer_handler.step("analysis"))
            self.to_stream_tokens.append(self.text_streamer_handler.step("<|message|>"))

        sampler = make_sampler(
            temp=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
        )

        return sampler, prompt

    def __call__(
        self, text: Union[List[dict], str], args: CHATCallArgs = CHATCallArgs()
    ) -> CHATResult:
        messages = format_messages(text)

        if self.eric_search is not None:
            search_query = create_search_prompt_chat(text)
            data_result = self.eric_search(search_query, args=args.search_args)

            if data_result:
                top_result = data_result[0]
                rag_content = formate_rag_content(
                    text=search_query, data_result=top_result
                )
                rag_message = formate_rag_message(rag_content=rag_content)
                messages.insert(-1, rag_message)

        sampler, prompt = self._get_streamer_prompt(messages=messages, args=args)
        out = []
        for resp in stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            max_tokens=args.max_len,
            sampler=sampler,
        ):
            stream_result = self.text_streamer_handler.step(resp.text)
            if stream_result:
                if stream_result.marker == "text":
                    out.append(resp.text)

        final_text = "".join(out).strip()
        return CHATResult(text=final_text)

    def stream(
        self, text: Union[List[dict], str], args: CHATCallArgs = CHATCallArgs()
    ) -> Iterator[CHATStreamResult]:
        messages = format_messages(text)

        if self.eric_search is not None:
            search_query = create_search_prompt_chat(text)
            yield CHATStreamResult(text=search_query, marker="search", payload={})
            data_result = self.eric_search(search_query, args=args.search_args)
            if data_result:
                top_result = data_result[0]

                yield CHATStreamResult(
                    text=top_result.text,
                    marker="search_result",
                    payload={
                        "best_sentence": top_result.best_sentence,
                        "metadata": top_result.metadata,
                    },
                )
                rag_content = formate_rag_content(
                    text=search_query, data_result=top_result
                )
                rag_message = formate_rag_message(rag_content=rag_content)
                messages.insert(-1, rag_message)

        sampler, prompt = self._get_streamer_prompt(messages=messages, args=args)

        while self.to_stream_tokens:
            stream_result = self.to_stream_tokens.pop(0)
            if stream_result:
                yield stream_result

        for resp in stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            max_tokens=args.max_len,
            sampler=sampler,
        ):
            stream_result = self.text_streamer_handler.step(resp.text)
            if stream_result:
                yield stream_result

    def save(self, path: str):
        if self.model is None or self.tokenizer is None:
            raise EricNoModelError("No model/tokenizer loaded")

        os.makedirs(path, exist_ok=True)

        try:
            save_model(model=self.model, save_path=path)
        except Exception as e:
            raise EricSaveError(f"Error saving MLX model to {path}: {e}") from e

        try:
            self.tokenizer.save_pretrained(path)
        except Exception as e:
            raise EricSaveError(f"Error saving MLX tokenizer to {path}: {e}") from e

        try:
            self.config.save_pretrained(path)
        except Exception as e:
            raise EricSaveError(f"Error saving config to {path}: {e}") from e

    def push(self, repo_id: str, private: bool = True):
        api = HfApi()
        try:
            api.create_repo(repo_id, exist_ok=True, private=private)
        except Exception as e:
            self.logger.warning(f"Could not create repo: {e}")
            return
        try:
            has_readme = api.file_exists(repo_id, "README.md")
        except Exception as e:
            self.logger.warning(f"Could not check for an existing README: {e}")
            return

        if not has_readme:
            readme_text = self._get_readme(repo_id)
            try:
                self.logger.info("Pushing README...")

                api.upload_file(
                    path_or_fileobj=readme_text.encode("utf-8"),
                    path_in_repo="README.md",
                    repo_id=repo_id,
                    repo_type="model",
                )

            except Exception as e:
                # Don't fail the whole push if README upload fails; just warn.
                self.logger.warning(f"Error pushing README: {e}")

        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                self.save(tmpdir)
                api.upload_large_folder(
                    repo_id, tmpdir, repo_type="model", private=private
                )

        except Exception as e:
            raise EricPushError(f"Error uploading model and tokenizer: {e}") from e

    def _get_readme(self, repo_id: str) -> str:
        readme_text = textwrap.dedent(f"""\
            ---
            tags:
            - erictransformer
            - eric-chat-mlx
            - mlx
            ---
            # {repo_id}

            ## Installation

            On Mac:
            ```
            pip install mlx-lm erictransformer
            ```

            ## Usage

            ```python
            from erictransformer import EricChatMLX, CHATCallArgs

            eric_chat = EricChatMLX(model_name="{repo_id}")

            text = 'Hello world'

            result = eric_chat(text)
            print(result.text)

            # Streaming is also possible (see docs)
            ```

            See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.
            """)

        return readme_text
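The MLX variant keeps the same calling convention but swaps `model.generate` plus a background thread for `mlx_lm.stream_generate`, which is already iterator-based, and it has no train/eval path. A sketch of the round trip under the same assumptions as above; the save path and repo id are placeholders:

```python
from erictransformer import EricChatMLX

# Default checkpoint from the constructor signature above.
eric_chat = EricChatMLX(model_name="mlx-community/SmolLM3-3B-4bit")

result = eric_chat("Hello world")
print(result.text)

for chunk in eric_chat.stream("Hello world"):
    if chunk.marker == "text":
        print(chunk.text, end="", flush=True)

# Export the MLX weights, tokenizer, and config, then mirror them to the
# Hub; push() stages to a temp dir and uses HfApi.upload_large_folder().
eric_chat.save("./smollm3-mlx")          # placeholder path
eric_chat.push("your-user/smollm3-mlx")  # placeholder repo id
```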