erictransformer 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,369 @@
1
+ import textwrap
2
+ import threading
3
+ import traceback
4
+ from typing import Iterator, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from datasets import Dataset
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ GenerationConfig,
11
+ PretrainedConfig,
12
+ PreTrainedModel,
13
+ PreTrainedTokenizerBase,
14
+ TextIteratorStreamer,
15
+ default_data_collator,
16
+ )
17
+ from ericsearch import EricSearch
18
+
19
+
20
+ from erictransformer.eval_models import EvalModel
21
+ from erictransformer.args import EricTrainArgs, EricEvalArgs
22
+ from erictransformer.eric_tasks.args import (
23
+ CHATCallArgs,
24
+ CHATTokArgs,
25
+ )
26
+ from erictransformer.eric_tasks.chat_stream_handlers import (
27
+ CHATStreamResult,
28
+ DefaultStreamHandler,
29
+ GPTOSSSMHandler,
30
+ SmolStreamHandler,
31
+ )
32
+ from erictransformer.eric_tasks.chat_templates import map_chat_roles
33
+ from erictransformer.eric_tasks.misc import (
34
+ create_search_prompt_chat,
35
+ format_messages,
36
+ formate_rag_content,
37
+ formate_rag_message,
38
+ generate_gen_kwargs,
39
+ get_pad_eos,
40
+ )
41
+ from erictransformer.eric_tasks.results import CHATResult
42
+ from erictransformer.eric_tasks.tok.tok_functions import (
43
+ get_max_in_len,
44
+ tokenize_chat_template,
45
+ )
46
+ from erictransformer.eric_transformer import EricTransformer, EricTransformerArgs
47
+ from erictransformer.loops import EvalResult
48
+ from erictransformer.utils import get_model_components
49
+ from erictransformer.validator import CHATValidator
50
+
51
+
52
+ class EricChat(EricTransformer):
53
def __init__(
    self,
    model_name: Union[str, PreTrainedModel, None] = "openai/gpt-oss-20b",
    *,
    trust_remote_code: bool = False,
    tokenizer: Union[str, PreTrainedTokenizerBase, None] = None,
    eric_search: Optional[EricSearch] = None,
):
    """Chat task wrapper around a causal language model.

    Args:
        model_name: Hub repo id / local path, an already-loaded model, or
            None to defer model loading.
        trust_remote_code: Forwarded to the Hugging Face loaders.
        tokenizer: Tokenizer repo id / path or instance; defaults to the
            model's own tokenizer. (Annotation fixed to admit None.)
        eric_search: Optional retrieval backend enabling RAG.

    Raises:
        ValueError: If the resolved tokenizer has no chat template.
    """
    model_class = AutoModelForCausalLM

    eric_args = EricTransformerArgs(
        model_name=model_name,
        model_class=model_class,
        trust_remote_code=trust_remote_code,
        tokenizer=tokenizer,
    )

    super().__init__(eric_args)

    self.task_validator = CHATValidator(
        model_name=model_name,
        trust_remote_code=trust_remote_code,
        tokenizer=tokenizer,
        logger=self.logger,
    )

    # A chat template is required to turn message lists into prompts.
    if not getattr(self.tokenizer, "chat_template", None):
        raise ValueError("The tokenizer must include a chat template")

    self._data_collator = default_data_collator

    self.logger.info("Using tokenizer's built-in chat template.")
    if model_name:
        self.config = self.model.config

    if self.model is not None:
        self.pad_token_id, self.eos_token_id = get_pad_eos(
            self.tokenizer, self.model
        )

    if (
        model_name is not None
    ):  # we don't need to initialize these if a model is not provided.
        # Pick the stream handler for the model family; the handler decides
        # which generated tokens are user-visible text.
        if self.config.model_type == "smollm3":
            self.text_streamer_handler = SmolStreamHandler(self.tokenizer, self.logger)
            self.model_type = "smollm3"
        elif self.config.model_type == "gpt_oss":
            self.text_streamer_handler = GPTOSSSMHandler(self.tokenizer, self.logger)
            self.model_type = "gpt_oss"
        else:
            self.text_streamer_handler = DefaultStreamHandler(self.tokenizer)
            self.model_type = "default"

    self.eric_search = eric_search if eric_search else None

    if self.model:
        self._prep_model()

    # Stream results queued while building a gpt_oss prompt (forced
    # "analysis" channel); drained by stream().
    self.to_stream_tokens = []
115
+
116
def _get_call_thread_streamer(
    self, messages: list[dict], args: Optional[CHATCallArgs] = None
):
    """Prepare a background generation thread plus its token streamer.

    Args:
        messages: Role-mapped chat messages to prompt with.
        args: Generation options; a fresh CHATCallArgs is built per call
            (the old `CHATCallArgs()` default was a shared mutable default).

    Returns:
        (gen_thread, gen_streamer): the thread is NOT started; the caller
        starts it and consumes decoded text chunks from the streamer.
    """
    if args is None:
        args = CHATCallArgs()

    mapped_messages = map_chat_roles(
        messages=messages,
        model_name=self.eric_args.model_name,
        model_type=self.model_type,
    )
    if self.model_type != "gpt_oss":
        input_ids = self.tokenizer.apply_chat_template(
            mapped_messages, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        prompt = self.tokenizer.apply_chat_template(
            mapped_messages, add_generation_prompt=True, tokenize=False
        )

        # Force gpt-oss into its "analysis" channel and queue the matching
        # stream markers so stream() can replay them to the consumer.
        prompt += "<|channel|>analysis<|message|>"
        self.to_stream_tokens.append(self.text_streamer_handler.step("<|channel|>"))
        self.to_stream_tokens.append(self.text_streamer_handler.step("analysis"))
        self.to_stream_tokens.append(self.text_streamer_handler.step("<|message|>"))
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

    # apply_chat_template may return an unbatched tensor.
    if input_ids.ndim == 1:
        input_ids = input_ids.unsqueeze(0)

    input_ids = input_ids.to(self.model.device)

    attention_mask = torch.ones_like(
        input_ids, dtype=torch.long, device=self.model.device
    )

    gen_streamer = TextIteratorStreamer(
        self.tokenizer, skip_prompt=True, skip_special_tokens=False
    )

    gen_kwargs = generate_gen_kwargs(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=gen_streamer,
        args=args,
        eos_token_id=self.eos_token_id,
        pad_token_id=self.pad_token_id,
    )

    def thread_fn():
        try:
            self.model.generate(**gen_kwargs)
        except BaseException:
            # Catching BaseException is deliberate: this runs in a
            # background thread, so KeyboardInterrupt and friends are not
            # being muted for the main thread. Log, then end the streamer
            # so the consumer loop unblocks.
            err_str = traceback.format_exc()
            self.logger.error(f"Error in generate thread: {err_str}")
            gen_streamer.end()

    gen_thread = threading.Thread(target=thread_fn)
    return gen_thread, gen_streamer
176
+
177
def __call__(
    self, text: Union[str, List[dict]], args: Optional[CHATCallArgs] = None
) -> CHATResult:
    """Run one chat completion and return the assistant's visible text.

    Args:
        text: A raw user message or a full chat message list.
        args: Generation options; a fresh CHATCallArgs is built per call
            (the old `CHATCallArgs()` default was a shared mutable default).
    """
    if args is None:
        args = CHATCallArgs()

    messages = format_messages(text)

    if self.eric_search is not None:
        search_query = create_search_prompt_chat(text)
        data_result = self.eric_search(search_query, args=args.search_args)

        if data_result:
            top_result = data_result[0]
            rag_content = formate_rag_content(text=text, data_result=top_result)
            rag_message = formate_rag_message(rag_content=rag_content)
            # Inject retrieved context just before the final user message.
            messages.insert(-1, rag_message)

    self._get_model_ready_inference()
    gen_thread, gen_streamer = self._get_call_thread_streamer(messages, args)
    gen_thread.start()
    out_text = []
    try:
        # `chunk` no longer shadows the `text` parameter (previous code
        # reused the name, a latent-bug smell).
        for chunk in gen_streamer:
            for token in self.tokenizer.encode(chunk):
                token_string = self.tokenizer.decode(token)
                stream_result = self.text_streamer_handler.step(token_string)
                if stream_result and stream_result.marker == "text":
                    out_text.append(stream_result.text)
    finally:
        gen_thread.join()

    # NOTE(review): for gpt_oss models _get_call_thread_streamer queues
    # entries in self.to_stream_tokens that __call__ never drains — they
    # would leak into a later stream() call; confirm intended.
    final_text = "".join(out_text)

    return CHATResult(text=final_text)
211
+
212
def stream(
    self, text: Union[str, List[dict]], args: Optional[CHATCallArgs] = None
) -> Iterator[CHATStreamResult]:
    """Yield CHATStreamResult events as the model generates.

    Emits "search" / "search_result" events first when RAG is enabled,
    then replays any queued prompt markers (gpt_oss analysis preamble),
    then streams per-token handler results.

    Args:
        text: A raw user message or a full chat message list.
        args: Generation options; a fresh CHATCallArgs is built per call
            (the old `CHATCallArgs()` default was a shared mutable default).
    """
    if args is None:
        args = CHATCallArgs()

    messages = format_messages(text)

    if self.eric_search is not None:
        search_query = create_search_prompt_chat(text)
        yield CHATStreamResult(
            text="", marker="search", payload={"query": search_query}
        )

        data_result = self.eric_search(search_query, args=args.search_args)

        if data_result:
            top_result = data_result[0]
            yield CHATStreamResult(
                text="",
                marker="search_result",
                payload={
                    "text": top_result.text,
                    "best_sentence": top_result.best_sentence,
                    "metadata": top_result.metadata,
                },
            )

            rag_content = formate_rag_content(text=text, data_result=top_result)
            rag_message = formate_rag_message(rag_content=rag_content)
            messages.insert(-1, rag_message)

    self._get_model_ready_inference()

    gen_thread, gen_streamer = self._get_call_thread_streamer(messages, args)

    # Replay markers queued during prompt construction (gpt_oss).
    while self.to_stream_tokens:
        queued = self.to_stream_tokens.pop(0)
        if queued:
            yield queued

    gen_thread.start()
    try:
        # `chunk` no longer shadows the `text` parameter.
        for chunk in gen_streamer:
            for token in self.tokenizer.encode(chunk):
                token_string = self.tokenizer.decode(token)
                stream_result = self.text_streamer_handler.step(token_string)
                if stream_result:
                    yield stream_result

    finally:
        gen_thread.join()
262
+
263
def _tok_function(
    self,
    raw_dataset,
    args: Optional[CHATTokArgs] = None,
    file_type: str = "jsonl",
    procs: Optional[int] = None,
) -> Dataset:
    """Tokenize a raw chat dataset using the tokenizer's chat template.

    Args:
        raw_dataset: Dataset of chat examples to tokenize.
        args: Tokenization options; a fresh CHATTokArgs is built per call
            (the old `CHATTokArgs()` default was a shared mutable default).
        file_type: Accepted for interface parity with other tasks; unused here.
        procs: Number of worker processes for dataset.map, or None.
    """
    if args is None:
        args = CHATTokArgs()

    max_in_len = get_max_in_len(args.max_len, self.tokenizer)

    return tokenize_chat_template(
        tokenizer=self.tokenizer,
        dataset=raw_dataset,
        max_len=max_in_len,
        bs=args.bs,
        procs=procs,
    )
279
+
280
def train(
    self,
    train_path: str = "",
    args: Optional[EricTrainArgs] = None,
    eval_path: str = "",
    *,
    resume_path: str = "",
):
    """Fine-tune the chat model via the shared EricTransformer train loop.

    A fresh EricTrainArgs is built per call; the old `EricTrainArgs()`
    default was a single shared (mutable) instance.
    """
    if args is None:
        args = EricTrainArgs()
    return super().train(train_path, args, eval_path, resume_path=resume_path)
291
+
292
def eval(
    self, eval_path: str = "", args: Optional[EricEvalArgs] = None
) -> EvalResult:
    """Evaluate via the shared EricTransformer eval loop.

    A fresh EricEvalArgs is built per call; the old `EricEvalArgs()`
    default was a single shared (mutable) instance.
    """
    if args is None:
        args = EricEvalArgs()
    return super().eval(eval_path=eval_path, args=args)
296
+
297
def tok(self, path: str, out_dir: str, args: Optional[CHATTokArgs] = None):
    """Tokenize the dataset at `path` into `out_dir` via the base task.

    A fresh CHATTokArgs is built per call; the old `CHATTokArgs()`
    default was a single shared (mutable) instance.
    """
    if args is None:
        args = CHATTokArgs()
    return super().tok(path=path, out_dir=out_dir, args=args)
301
+
302
def _load_model_components(
    self,
) -> Tuple[PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel]:
    """Load config, tokenizer, and model as configured in self.eric_args."""
    cfg = self.eric_args
    return get_model_components(
        model_name_path=cfg.model_name,
        trust_remote_code=cfg.trust_remote_code,
        model_class=cfg.model_class,
        tokenizer_path=cfg.tokenizer,
        precision=self.precision_type,
    )
312
+
313
def _format_tokenized_example(self, example: dict) -> dict:
    """Project a tokenized example down to the fields the collator consumes."""
    wanted_keys = ("input_ids", "attention_mask", "labels")
    return {key: example[key] for key in wanted_keys}
319
+
320
def _get_default_eval_models(self) -> List[EvalModel]:
    """Chat ships with no default evaluation models."""
    no_defaults: List[EvalModel] = []
    return no_defaults
322
+
323
def _prep_model(self):
    """Attach a sampling GenerationConfig (seeded from CHATCallArgs defaults)."""
    defaults = CHATCallArgs()
    gen_cfg = GenerationConfig.from_model_config(self.model.config)
    gen_cfg.num_beams = 1
    gen_cfg.early_stopping = False
    gen_cfg.do_sample = True
    # NOTE(review): `min_len` / `max_len` / `temp` are not the standard
    # GenerationConfig field names (`min_length` / `max_length` /
    # `temperature`); confirm generate() actually consumes these.
    gen_cfg.min_len = defaults.min_len
    gen_cfg.max_len = defaults.max_len
    gen_cfg.temp = defaults.temp
    gen_cfg.top_p = defaults.top_p
    self.model.generation_config = gen_cfg
334
+
335
def _get_readme(self, repo_id: str) -> str:
    """Render the model-card markdown pushed alongside a chat model."""
    card = textwrap.dedent(f"""\
    ---
    tags:
    - erictransformer
    - eric-chat
    ---
    # {repo_id}

    ## Installation:

    ```
    pip install erictransformer
    ```

    ## Usage

    ```python
    from erictransformer import EricChat, CHATCallArgs

    eric_chat = EricChat(model_name="{repo_id}")

    text = 'Hello world'

    result = eric_chat(text)
    print(result.text)

    # Streaming is also possible (see docs)
    ```

    See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.

    """)

    return card
@@ -0,0 +1,278 @@
1
+ import os
2
+ import tempfile
3
+ import textwrap
4
+ from typing import Iterator, List, Optional, Union
5
+
6
+ from huggingface_hub import HfApi
7
+ from transformers import AutoConfig
8
+ from ericsearch import EricSearch
9
+
10
+
11
+ try:
12
+ from mlx_lm import load, stream_generate
13
+ from mlx_lm.sample_utils import make_sampler
14
+ from mlx_lm.utils import save_model
15
+ except ImportError as err:
16
+ raise ImportError("""
17
+ Failed to import MLX. If you are using a Mac,
18
+ try `pip install mlx-lm`. If you have a CUDA-compatible device,
19
+ try `pip install mlx[cuda] mlx-lm`. Otherwise, try `pip install mlx[cpu] mlx-lm`
20
+ """) from err
21
+
22
+ from erictransformer.exceptions import (
23
+ EricNoModelError,
24
+ EricPushError,
25
+ EricSaveError,
26
+ )
27
+ from erictransformer.eric_tasks.args import CHATCallArgs
28
+ from erictransformer.eric_tasks.chat_stream_handlers import (
29
+ CHATStreamResult,
30
+ DefaultStreamHandler,
31
+ GPTOSSSMHandler,
32
+ SmolStreamHandler,
33
+ )
34
+ from erictransformer.eric_tasks.chat_templates import map_chat_roles
35
+ from erictransformer.eric_tasks.misc import (
36
+ create_search_prompt_chat,
37
+ format_messages,
38
+ formate_rag_content,
39
+ formate_rag_message,
40
+ )
41
+ from erictransformer.eric_tasks.results import CHATResult
42
+ from erictransformer.utils import et_get_logger
43
+
44
+
45
+ class EricChatMLX:
46
def __init__(
    self,
    model_name: str = "mlx-community/SmolLM3-3B-4bit",
    *,
    eric_search: Optional[EricSearch] = None,
):
    """Chat inference on MLX builds of a causal language model.

    Args:
        model_name: Hub repo id or local path of an MLX model.
        eric_search: Optional retrieval backend enabling RAG.

    Raises:
        ValueError: If the tokenizer has no chat template.
    """
    self.model_name = model_name
    self.model, self.tokenizer = load(model_name)
    self.logger = et_get_logger()

    # Without a chat template we cannot turn message lists into prompts.
    if not getattr(self.tokenizer, "chat_template", None):
        raise ValueError("The tokenizer must include a chat template")

    self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=False)

    # Pick the stream handler for the model family; it filters which
    # generated chunks count as user-visible text.
    family = self.config.model_type
    if family == "smollm3":
        self.model_type = "smollm3"
        self.text_streamer_handler = SmolStreamHandler(self.tokenizer, self.logger)
    elif family == "gpt_oss":
        self.model_type = "gpt_oss"
        self.text_streamer_handler = GPTOSSSMHandler(self.tokenizer, self.logger)
    else:
        self.model_type = "default"
        self.text_streamer_handler = DefaultStreamHandler(self.tokenizer)

    self.eric_search = eric_search if eric_search else None

    # Stream results queued during prompt construction (gpt_oss analysis
    # preamble); drained by stream().
    self.to_stream_tokens = []
78
+
79
def _get_streamer_prompt(
    self, messages: List[dict], args: Optional[CHATCallArgs] = None
):
    """Build the chat prompt and an MLX sampler for one generation call.

    Fixes: the `messages` annotation was a pointless one-member
    `Union[List[dict]]`, and the `CHATCallArgs()` default was a single
    shared (mutable) instance.

    Returns:
        (sampler, prompt): the mlx-lm sampler and the rendered prompt string.
    """
    if args is None:
        args = CHATCallArgs()

    mapped_messages = map_chat_roles(
        messages=messages, model_name=self.model_name, model_type=self.model_type
    )

    prompt = self.tokenizer.apply_chat_template(
        mapped_messages, add_generation_prompt=True, tokenize=False
    )
    # Always think: force gpt-oss into its analysis channel and queue the
    # matching stream markers for stream() to replay.
    if self.model_type == "gpt_oss":
        prompt += "<|channel|>analysis<|message|>"
        self.to_stream_tokens.append(self.text_streamer_handler.step("<|channel|>"))
        self.to_stream_tokens.append(self.text_streamer_handler.step("analysis"))
        self.to_stream_tokens.append(self.text_streamer_handler.step("<|message|>"))

    sampler = make_sampler(
        temp=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
    )

    return sampler, prompt
104
+
105
def __call__(
    self, text: Union[List[dict], str], args: Optional[CHATCallArgs] = None
) -> CHATResult:
    """Run one chat completion and return the assistant's visible text.

    Args:
        text: A raw user message or a full chat message list.
        args: Generation options; a fresh CHATCallArgs is built per call
            (the old `CHATCallArgs()` default was a shared mutable default).
    """
    if args is None:
        args = CHATCallArgs()

    messages = format_messages(text)

    if self.eric_search is not None:
        search_query = create_search_prompt_chat(text)
        data_result = self.eric_search(search_query, args=args.search_args)

        # Truthiness instead of the old `if len(data_result):`.
        if data_result:
            top_result = data_result[0]
            # NOTE(review): EricChat passes the raw user text here while
            # this class passes the derived search query — confirm which
            # is intended.
            rag_content = formate_rag_content(
                text=search_query, data_result=top_result
            )
            rag_message = formate_rag_message(rag_content=rag_content)
            messages.insert(-1, rag_message)

    sampler, prompt = self._get_streamer_prompt(messages=messages, args=args)
    out = []
    for resp in stream_generate(
        self.model,
        self.tokenizer,
        prompt,
        max_tokens=args.max_len,
        sampler=sampler,
    ):
        stream_result = self.text_streamer_handler.step(resp.text)
        if stream_result and stream_result.marker == "text":
            out.append(resp.text)

    final_text = "".join(out).strip()
    return CHATResult(text=final_text)
136
+
137
def stream(
    self, text: Union[List[dict], str], args: Optional[CHATCallArgs] = None
) -> Iterator[CHATStreamResult]:
    """Yield CHATStreamResult events as the MLX model generates.

    Emits "search" / "search_result" events first when RAG is enabled,
    replays queued prompt markers (gpt_oss analysis preamble), then
    streams per-chunk handler results.

    Args:
        text: A raw user message or a full chat message list.
        args: Generation options; a fresh CHATCallArgs is built per call
            (the old `CHATCallArgs()` default was a shared mutable default).
    """
    if args is None:
        args = CHATCallArgs()

    messages = format_messages(text)

    if self.eric_search is not None:
        search_query = create_search_prompt_chat(text)
        yield CHATStreamResult(text=search_query, marker="search", payload={})
        data_result = self.eric_search(search_query, args=args.search_args)
        if data_result:
            top_result = data_result[0]

            yield CHATStreamResult(
                text=top_result.text,
                marker="search_result",
                payload={
                    "best_sentence": top_result.best_sentence,
                    "metadata": top_result.metadata,
                },
            )
            rag_content = formate_rag_content(
                text=search_query, data_result=top_result
            )
            rag_message = formate_rag_message(rag_content=rag_content)
            messages.insert(-1, rag_message)

    sampler, prompt = self._get_streamer_prompt(messages=messages, args=args)

    # Replay markers queued during prompt construction (gpt_oss).
    while self.to_stream_tokens:
        queued = self.to_stream_tokens.pop(0)
        if queued:
            yield queued

    for resp in stream_generate(
        self.model,
        self.tokenizer,
        prompt,
        max_tokens=args.max_len,
        sampler=sampler,
    ):
        stream_result = self.text_streamer_handler.step(resp.text)
        if stream_result:
            yield stream_result
180
+
181
def save(self, path: str):
    """Persist the MLX model, tokenizer, and config to `path`.

    Raises:
        EricNoModelError: If no model/tokenizer is loaded.
        EricSaveError: If any component fails to save; the original
            exception is now chained with `from e` so its traceback is
            preserved (previously the cause was dropped).
    """
    if self.model is None or self.tokenizer is None:
        raise EricNoModelError("No model/tokenizer loaded")

    os.makedirs(path, exist_ok=True)

    try:
        save_model(model=self.model, save_path=path)
    except Exception as e:
        raise EricSaveError(f"Error saving MLX model to {path}: {e}") from e

    try:
        self.tokenizer.save_pretrained(path)
    except Exception as e:
        raise EricSaveError(f"error saving MLX tokenizer {path}: {e}") from e

    try:
        self.config.save_pretrained(path)
    except Exception as e:
        raise EricSaveError(f"Error saving config {path}: {e}") from e
201
+
202
def push(self, repo_id: str, private: bool = True):
    """Upload model, tokenizer, config, and (if absent) a README to the Hub.

    Args:
        repo_id: Target repository, e.g. "user/model".
        private: Whether a newly created repo is private. Now also honoured
            by the folder upload, which previously hard-coded `private=True`
            and silently ignored this parameter.

    Raises:
        EricPushError: If the model/tokenizer upload fails (cause chained).
    """
    api = HfApi()
    try:
        api.create_repo(repo_id, exist_ok=True, private=private)
    except Exception as e:
        # Fixed typo: "crate" -> "create".
        self.logger.warning(f"Could not create repo {e}")
        return
    try:
        has_readme = api.file_exists(repo_id, "README.md")
    except Exception as e:
        # Fixed garbled message: "Could not info".
        self.logger.warning(f"Could not check for README: {e}")
        return

    if not has_readme:
        readme_text = self._get_readme(repo_id)
        try:
            self.logger.info("Pushing README...")

            api.upload_file(
                path_or_fileobj=readme_text.encode("utf-8"),
                path_in_repo="README.md",
                repo_id=repo_id,
                repo_type="model",
            )

        except Exception as e:
            # Don't fail the whole push if README upload fails; just warn.
            self.logger.warning(f"Error pushing README: {e}")

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            self.save(tmpdir)
            # Bug fix: pass the caller's `private` flag instead of True.
            api.upload_large_folder(
                repo_id, tmpdir, repo_type="model", private=private
            )

    except Exception as e:
        raise EricPushError(f"Error uploading model and tokenizer: {e}") from e
240
+
241
def _get_readme(self, repo_id: str) -> str:
    """Render the model-card markdown pushed alongside an MLX chat model."""
    card = textwrap.dedent(f"""\
    ---
    tags:
    - erictransformer
    - eric-chat-mlx
    - mlx

    ---
    # {repo_id}

    ## Installation

    On Mac
    ```
    pip install mlx-lm erictransformer
    ```

    ## Usage

    ```python
    from erictransformer import EricChatMLX, CHATCallArgs

    eric_chat = EricChatMLX(model_name="{repo_id}")

    text = 'Hello world'

    result = eric_chat(text)
    print(result.text)

    # Streaming is also possible (see docs)
    ```

    See Eric Transformer's [GitHub](https://github.com/ericfillion/erictransformer) for more information.

    """)

    return card