erictransformer 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. erictransformer/__init__.py +44 -0
  2. erictransformer/args/__init__.py +7 -0
  3. erictransformer/args/eric_args.py +50 -0
  4. erictransformer/eric_tasks/__init__.py +47 -0
  5. erictransformer/eric_tasks/args/__init__.py +16 -0
  6. erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
  7. erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
  8. erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
  9. erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
  10. erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
  11. erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
  12. erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
  13. erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
  14. erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
  15. erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
  16. erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
  17. erictransformer/eric_tasks/chat_templates/convert.py +67 -0
  18. erictransformer/eric_tasks/eric_chat.py +369 -0
  19. erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
  20. erictransformer/eric_tasks/eric_generation.py +243 -0
  21. erictransformer/eric_tasks/eric_text_classification.py +231 -0
  22. erictransformer/eric_tasks/eric_text_to_text.py +283 -0
  23. erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
  24. erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
  25. erictransformer/eric_tasks/misc/__init__.py +11 -0
  26. erictransformer/eric_tasks/misc/call_utils.py +69 -0
  27. erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
  28. erictransformer/eric_tasks/misc/rag.py +17 -0
  29. erictransformer/eric_tasks/results/__init__.py +6 -0
  30. erictransformer/eric_tasks/results/call_results.py +30 -0
  31. erictransformer/eric_tasks/tok/__init__.py +0 -0
  32. erictransformer/eric_tasks/tok/tok_functions.py +118 -0
  33. erictransformer/eric_tracker/__init__.py +1 -0
  34. erictransformer/eric_tracker/eric_tracker.py +256 -0
  35. erictransformer/eric_tracker/save_plot.py +422 -0
  36. erictransformer/eric_transformer.py +534 -0
  37. erictransformer/eval_models/__init__.py +1 -0
  38. erictransformer/eval_models/eval_model.py +75 -0
  39. erictransformer/exceptions/__init__.py +19 -0
  40. erictransformer/exceptions/eric_exceptions.py +74 -0
  41. erictransformer/loops/__init__.py +2 -0
  42. erictransformer/loops/eval_loop.py +111 -0
  43. erictransformer/loops/train_loop.py +310 -0
  44. erictransformer/utils/__init__.py +21 -0
  45. erictransformer/utils/init/__init__.py +5 -0
  46. erictransformer/utils/init/get_components.py +204 -0
  47. erictransformer/utils/init/get_device.py +22 -0
  48. erictransformer/utils/init/get_logger.py +15 -0
  49. erictransformer/utils/load_from_repo_or_path.py +14 -0
  50. erictransformer/utils/test/__init__.py +1 -0
  51. erictransformer/utils/test/debug_hook.py +20 -0
  52. erictransformer/utils/timer/__init__.py +1 -0
  53. erictransformer/utils/timer/eric_timer.py +145 -0
  54. erictransformer/utils/tok_data/__init__.py +8 -0
  55. erictransformer/utils/tok_data/num_proc.py +15 -0
  56. erictransformer/utils/tok_data/save_tok_data.py +36 -0
  57. erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
  58. erictransformer/utils/tok_data/tok_helpers.py +79 -0
  59. erictransformer/utils/train/__init__.py +6 -0
  60. erictransformer/utils/train/confirm_optimizer.py +18 -0
  61. erictransformer/utils/train/create_dir.py +72 -0
  62. erictransformer/utils/train/get_num_training_steps.py +15 -0
  63. erictransformer/utils/train/get_precision.py +22 -0
  64. erictransformer/utils/train/get_tok_data.py +105 -0
  65. erictransformer/utils/train/resume.py +62 -0
  66. erictransformer/validator/__init__.py +11 -0
  67. erictransformer/validator/eric/__init__.py +2 -0
  68. erictransformer/validator/eric/eval_validator.py +75 -0
  69. erictransformer/validator/eric/train_validator.py +143 -0
  70. erictransformer/validator/eric_validator.py +10 -0
  71. erictransformer/validator/tasks/__init__.py +5 -0
  72. erictransformer/validator/tasks/chat_validator.py +28 -0
  73. erictransformer/validator/tasks/gen_validator.py +28 -0
  74. erictransformer/validator/tasks/task_validator.py +54 -0
  75. erictransformer/validator/tasks/tc_validator.py +45 -0
  76. erictransformer/validator/tasks/tt_validator.py +28 -0
  77. erictransformer/validator/tok/__init__.py +1 -0
  78. erictransformer/validator/tok/tok_validator.py +23 -0
  79. erictransformer-0.0.1.dist-info/METADATA +72 -0
  80. erictransformer-0.0.1.dist-info/RECORD +83 -0
  81. erictransformer-0.0.1.dist-info/WHEEL +5 -0
  82. erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
  83. erictransformer-0.0.1.dist-info/top_level.txt +1 -0
erictransformer/__init__.py
@@ -0,0 +1,44 @@
+ from erictransformer.eric_tasks import (
+     CHATCallArgs,
+     CHATResult,
+     CHATStreamResult,
+     CHATTokArgs,
+     GENCallArgs,
+     GENResult,
+     GENTokArgs,
+     EricChat,
+     EricGeneration,
+     EricTextClassification,
+     EricTextToText,
+     TCCallArgs,
+     TCResult,
+     TCTokArgs,
+     TTCallArgs,
+     TTResult,
+     TTTokArgs
+ )
+
+ from erictransformer.args import EricTrainArgs, EricEvalArgs
+ from erictransformer.eric_transformer import EricTransformer
+ from erictransformer.loops import EvalResult, TrainResult
+
+ __all__ = []
+
+ try:
+     from mlx_lm import load, stream_generate
+     from mlx_lm.sample_utils import make_sampler
+     from mlx_lm.utils import save_model
+     mlx_enabled = True
+ except ImportError as err:
+     mlx_enabled = False
+
+ if mlx_enabled:
+     try:
+         from .eric_tasks.eric_chat_mlx import EricChatMLX
+     except Exception:
+         # Optional: log or ignore
+         pass
+     else:
+         __all__.append("EricChatMLX")
+
+ name = "erictransformer"
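
A note on the optional MLX path above: EricChatMLX becomes importable only when the mlx_lm dependency is present, and the package records that by appending the name to __all__. A minimal sketch of how a caller might branch on this (installation assumed; nothing below is part of the wheel):

import erictransformer

# "EricChatMLX" lands in __all__ only when the mlx_lm import above succeeded.
if "EricChatMLX" in erictransformer.__all__:
    from erictransformer import EricChatMLX as ChatBackend  # Apple-silicon / MLX backend
else:
    from erictransformer import EricChat as ChatBackend     # default transformers backend
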
erictransformer/args/__init__.py
@@ -0,0 +1,7 @@
+ from erictransformer.args.eric_args import (
+     CallResult,
+     CallArgs,
+     EricEvalArgs,
+     TokArgs,
+     EricTrainArgs,
+ )
erictransformer/args/eric_args.py
@@ -0,0 +1,50 @@
+ from dataclasses import dataclass
+
+ @dataclass(kw_only=True)
+ class EricTrainArgs:
+     # Learning parameters
+     lr: float = 2e-5
+     bs: int = 1
+     eval_bs: int = 0  # when 0 uses bs
+     epochs: int = 1
+     gas: int = 1
+     optim: str = "adamw"  # options adamw and sgd
+     lr_sched: str = "constant"  # other option: warmup_then_decay
+
+     # Action steps
+     eval_steps: int = 256  # if 0 no evaluating will be done
+     log_steps: int = 256  # if 0 no logging will be done
+     checkpoint_steps: int = -1  # if -1 no checkpointing will be done
+     save_best: bool = False  # saves the model with the lowest eval loss
+
+     # Misc
+     out_dir: str = "eric_transformer/"
+     run_name: str = ""
+     seed: int = 42
+
+
+ @dataclass(kw_only=True)
+ class TokArgs:
+     bs: int = 1024
+     max_cases: int = -1
+     shards: int = 1
+     procs: int = -1
+
+
+
+ @dataclass(kw_only=True)
+ class EricEvalArgs:
+     bs: int = 1
+     out_dir: str = "eric_transformer/"
+     run_name: str = ""
+     seed: int = 42
+
+
+ @dataclass(kw_only=True)
+ class CallArgs:
+     pass
+
+
+ @dataclass(kw_only=True)
+ class CallResult:
+     pass
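
Because these dataclasses are declared with kw_only=True, every field must be supplied by keyword. A minimal usage sketch (the values are illustrative, not package defaults beyond those shown above):

from erictransformer.args import EricTrainArgs, TokArgs

train_args = EricTrainArgs(lr=1e-4, bs=4, epochs=2, optim="sgd", lr_sched="warmup_then_decay")
tok_args = TokArgs(bs=512, procs=4)  # unspecified fields keep their defaults

# EricTrainArgs(1e-4, 4) would raise TypeError: kw_only=True forbids positional arguments.
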
erictransformer/eric_tasks/__init__.py
@@ -0,0 +1,47 @@
+ from erictransformer.eric_tasks.args import (
+     CHATCallArgs,
+     CHATTokArgs,
+     GENCallArgs,
+     GENTokArgs,
+     TCCallArgs,
+     TCTokArgs,
+     TTCallArgs,
+     TTTokArgs,
+ )
+ from erictransformer.eric_tasks.chat_stream_handlers import CHATStreamResult
+ from erictransformer.eric_tasks.eric_chat import EricChat
+ from erictransformer.eric_tasks.eric_generation import EricGeneration
+ from erictransformer.eric_tasks.eric_text_classification import (
+     EricTextClassification,
+ )
+ from erictransformer.eric_tasks.eric_text_to_text import (
+     EricTextToText,
+     TTStreamResult,
+ )
+ from erictransformer.eric_tasks.results import (
+     CHATResult,
+     GENResult,
+     TCResult,
+     TTResult
+ )
+
+ __all__ = []
+
+
+ try:
+     from mlx_lm import load, stream_generate
+     from mlx_lm.sample_utils import make_sampler
+     from mlx_lm.utils import save_model
+     mlx_enabled = True
+ except ImportError as err:
+     mlx_enabled = False
+
+
+ if mlx_enabled:
+     try:
+         from .eric_chat_mlx import EricChatMLX
+     except Exception:
+         # Optional: log or ignore
+         pass
+     else:
+         __all__.append("EricChatMLX")
erictransformer/eric_tasks/args/__init__.py
@@ -0,0 +1,16 @@
+ from erictransformer.eric_tasks.args.eric_chat_args import (
+     CHATCallArgs,
+     CHATTokArgs,
+ )
+ from erictransformer.eric_tasks.args.eric_generation_args import (
+     GENCallArgs,
+     GENTokArgs,
+ )
+ from erictransformer.eric_tasks.args.eric_text_classification_args import (
+     TCCallArgs,
+     TCTokArgs,
+ )
+ from erictransformer.eric_tasks.args.eric_text_to_text_args import (
+     TTCallArgs,
+     TTTokArgs,
+ )
erictransformer/eric_tasks/args/eric_chat_args.py
@@ -0,0 +1,21 @@
+ from dataclasses import dataclass, field
+ from ericsearch import SearchCallArgs
+
+ from erictransformer.args import CallArgs, TokArgs
+
+
+ @dataclass(kw_only=True)
+ class CHATCallArgs(CallArgs):
+     min_len: int = 1
+     max_len: int = 4096
+     temp: float = 0.8
+     top_k: int = 32
+     top_p: float = 0.6
+
+     # dataset args
+     search_args: SearchCallArgs = field(default_factory=SearchCallArgs)
+
+
+ @dataclass(kw_only=True)
+ class CHATTokArgs(TokArgs):
+     max_len: int = -1
erictransformer/eric_tasks/args/eric_generation_args.py
@@ -0,0 +1,20 @@
+ from dataclasses import dataclass
+
+ from erictransformer.args import CallArgs, TokArgs
+
+
+ @dataclass(kw_only=True)
+ class GENTokArgs(TokArgs):
+     max_len: int = 1024
+
+
+ @dataclass(kw_only=True)
+ class GENCallArgs(CallArgs):  # ← new canonical name
+     min_len: int = 1
+     max_len: int = 1024
+     temp: float = 0.8
+     top_k: int = 32
+     top_p: float = 0.6
+
+
+
erictransformer/eric_tasks/args/eric_text_classification_args.py
@@ -0,0 +1,13 @@
+ from dataclasses import dataclass
+
+ from erictransformer.args import CallArgs, TokArgs
+
+
+ @dataclass(kw_only=True)
+ class TCTokArgs(TokArgs):
+     max_len: int = -1
+
+
+ @dataclass(kw_only=True)
+ class TCCallArgs(CallArgs):
+     pass
erictransformer/eric_tasks/args/eric_text_to_text_args.py
@@ -0,0 +1,18 @@
+ from dataclasses import dataclass
+
+ from erictransformer.args import CallArgs, TokArgs
+
+
+ @dataclass(kw_only=True)
+ class TTCallArgs(CallArgs):
+     min_len: int = 1
+     max_len: int = 1024
+     temp: float = 0.8
+     top_k: int = 32
+     top_p: float = 0.6
+
+
+ @dataclass(kw_only=True)
+ class TTTokArgs(TokArgs):
+     max_in_len: int = -1
+     max_out_len: int = -1
erictransformer/eric_tasks/chat_stream_handlers/__init__.py
@@ -0,0 +1,6 @@
+ from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
+ from erictransformer.eric_tasks.chat_stream_handlers.default import (
+     DefaultStreamHandler,
+ )
+ from erictransformer.eric_tasks.chat_stream_handlers.gpt_oss import GPTOSSSMHandler
+ from erictransformer.eric_tasks.chat_stream_handlers.smol import SmolStreamHandler
erictransformer/eric_tasks/chat_stream_handlers/args.py
@@ -0,0 +1,13 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+ @dataclass(kw_only=True)
+ class CHATStreamResult:
+     text: str  # what user sees
+     marker: str  # the marker e.g text, special, tool
+     payload: Optional[dict]
+
+
+ @dataclass(kw_only=True)
+ class MarkerStrings:
+     pass
erictransformer/eric_tasks/chat_stream_handlers/default.py
@@ -0,0 +1,19 @@
+ from typing import Union
+
+ from transformers import PreTrainedTokenizer
+
+ from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
+ from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import (
+     StreamHandler,
+ )
+
+
+ class DefaultStreamHandler(StreamHandler):
+     def __init__(self, tokenizer: PreTrainedTokenizer):
+         super().__init__(tokenizer=tokenizer)
+
+     def step(self, token_str: str) -> Union[None, CHATStreamResult]:
+         if not token_str:
+             return None
+
+         return CHATStreamResult(text=token_str, marker="text", payload={})
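
A quick sketch of the handler contract: step() is fed one decoded token string at a time and either returns a CHATStreamResult or None. The tokenizer below is an arbitrary placeholder; any transformers tokenizer works with this handler.

from transformers import AutoTokenizer
from erictransformer.eric_tasks.chat_stream_handlers import DefaultStreamHandler

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model choice
handler = DefaultStreamHandler(tokenizer=tokenizer)

for chunk in ["Hello", ",", " world", ""]:
    result = handler.step(chunk)
    if result is not None:  # the empty chunk yields None
        print(result.marker, repr(result.text))  # always marker="text" for this handler
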
erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py
@@ -0,0 +1,147 @@
+ from dataclasses import dataclass
+ from typing import Union
+ import logging
+
+ from transformers import PreTrainedTokenizer
+
+ from erictransformer.eric_tasks.chat_stream_handlers.args import (
+     CHATStreamResult,
+     MarkerStrings,
+ )
+ from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import (
+     StreamHandler,
+ )
+
+
+ @dataclass(kw_only=True)
+ class GPTOSSMarkerStrings(MarkerStrings):
+     start: str = "<|start|>"
+     end: str = "<|end|>"
+     message: str = "<|message|>"
+     channel: str = "<|channel|>"
+     constrain: str = "<|constrain|>"
+     return_token: str = "<|return|>"
+     call: str = "<|call|>"
+
+
+ class GPTOSSSMHandler(StreamHandler):
+     def __init__(self, tokenizer: PreTrainedTokenizer, logger: logging.Logger):
+         self.markers = GPTOSSMarkerStrings()
+         self.current_channel = ""
+         self.in_thinking = False
+         self.in_tool = False
+         self.change_channel = False
+         self.just_received_message = False
+         self.start_just_happened = False
+         self.tool_strings = []
+         self.logger = logger
+         super().__init__(tokenizer=tokenizer)
+
+     def _reset_state(self):
+         self.current_channel = ""
+         self.in_thinking = False
+         self.in_tool = False
+         self.change_channel = False
+         self.just_received_message = False
+         self.start_just_happened = False
+         self.tool_strings = []
+
+     def step(self, token_str: str) -> Union[None, CHATStreamResult]:
+         if not token_str:
+             return None
+
+         stripped_token_str = token_str.strip()
+
+         self.in_thinking = False
+         #### SPECIAL TOKENS ####
+
+         # Case: <|start|>
+         if stripped_token_str == self.markers.start:
+             self._reset_state()
+             self.start_just_happened = True
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+
+         # Case: <|channel|>
+         if stripped_token_str == self.markers.channel:
+             self.change_channel = True
+             # next round
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+
+         # Case: <|message|>
+         if stripped_token_str == self.markers.message:
+
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+
+         # Case: <|return|>
+         if stripped_token_str == self.markers.return_token:
+             # do nothing. Streaming is over.
+             self._reset_state()
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+
+         # Case: <|constrain|>
+         if stripped_token_str == self.markers.constrain:
+             # do nothing. we always assume json
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+
+         # Case: <|end|>
+         if stripped_token_str == self.markers.end:
+             temp_current_channel = self.current_channel
+             self._reset_state()
+             if temp_current_channel == "analysis":
+                 return CHATStreamResult(
+                     text=stripped_token_str, marker="think_end", payload={}
+                 )
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+
+         #### NON SPECIAL TOKENS ####
+
+         if self.start_just_happened:
+             # we don't do anything with this header token
+             self.start_just_happened = False
+
+             if stripped_token_str in {"assistant", "user", "system"}:
+                 return CHATStreamResult(text=token_str, marker="special", payload={})
+
+         if self.change_channel:
+             self.change_channel = False
+
+             if self.current_channel == "commentary":
+                 self.logger.warning("The 'commentary' channel is not supported. Falling back to 'analysis'. ")
+                 self.current_channel = "analysis"
+
+             if stripped_token_str not in ("analysis", "final"):
+                 # fall back to final
+                 self.logger.warning(f"Unexpected channel '{stripped_token_str}'. Falling back to 'final'.")
+                 self.current_channel = "final"
+             else:
+                 self.current_channel = stripped_token_str
+
+             if stripped_token_str == "analysis":
+                 return CHATStreamResult(
+                     text=stripped_token_str, marker="think_start", payload={}
+                 )
+
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="special", payload={}
+             )
+         if self.current_channel == "analysis":
+             # just return the text
+             return CHATStreamResult(text=token_str, marker="thinking", payload={})
+
+         if self.current_channel == "final":
+             # just return the text
+             return CHATStreamResult(text=token_str, marker="text", payload={})
+
+         self._reset_state()
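
To make the state machine above concrete, here is an illustrative walk-through of the markers it emits for a typical harmony-style token stream (the tokenizer choice is an arbitrary placeholder; only the token strings drive the logic):

import logging
from transformers import AutoTokenizer
from erictransformer.eric_tasks.chat_stream_handlers import GPTOSSSMHandler

handler = GPTOSSSMHandler(
    tokenizer=AutoTokenizer.from_pretrained("gpt2"),  # placeholder; the base class only reads special-token metadata
    logger=logging.getLogger("gptoss"),
)

stream = ["<|start|>", "assistant", "<|channel|>", "analysis", "<|message|>",
          "Let me think", "<|end|>", "<|start|>", "assistant", "<|channel|>",
          "final", "<|message|>", "Here is the answer", "<|return|>"]
markers = [handler.step(tok).marker for tok in stream]
# -> special, special, special, think_start, special, thinking, think_end,
#    special, special, special, special, special, text, special
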
erictransformer/eric_tasks/chat_stream_handlers/smol.py
@@ -0,0 +1,81 @@
+ from dataclasses import dataclass
+ from typing import Union
+ import logging
+
+ from transformers import PreTrainedTokenizer
+
+ from erictransformer.eric_tasks.chat_stream_handlers.args import (
+     CHATStreamResult,
+     MarkerStrings,
+ )
+ from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import (
+     StreamHandler,
+ )
+
+
+ @dataclass(kw_only=True)
+ class SMOLMarkerStrings(MarkerStrings):
+     think_start: str
+     think_end: str
+     tool_start: str
+     tool_end: str
+
+
+ class SmolStreamHandler(StreamHandler):
+     def __init__(self, tokenizer: PreTrainedTokenizer, logger: logging.Logger):
+         self.markers: SMOLMarkerStrings = SMOLMarkerStrings(
+             think_start="<think>",
+             think_end="</think>",
+             tool_start="<tool_call>",  # not supported
+             tool_end="</tool_call>",  # not supported
+         )
+
+         self.marker_key_tuple = (
+             self.markers.think_start,
+             self.markers.think_end,
+             self.markers.tool_start,
+             self.markers.tool_end,
+         )
+
+         self.in_thinking = False
+
+         self.logger = logger
+
+         super().__init__(tokenizer=tokenizer)
+
+     def step(self, token_str: str) -> Union[None, CHATStreamResult]:
+         if not token_str:
+             return None
+
+         stripped_token_str = token_str.strip()
+
+         if stripped_token_str not in self.marker_key_tuple:
+             if self.in_thinking:
+                 return CHATStreamResult(text=token_str, marker="thinking", payload={})
+
+             if stripped_token_str in self.special_tokens:
+                 return CHATStreamResult(
+                     text=stripped_token_str, marker="special", payload={})
+
+             else:
+                 if token_str.endswith(self.eos_token):
+                     token_str = token_str[: -len(self.eos_token)]
+                 return CHATStreamResult(text=token_str, marker="text", payload={})
+
+         elif stripped_token_str == self.markers.think_start:
+             self.in_thinking = True
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="think_start", payload={}
+             )
+
+         elif stripped_token_str == self.markers.think_end:
+             self.in_thinking = False
+             return CHATStreamResult(
+                 text=stripped_token_str, marker="think_end", payload={}
+             )
+
+         elif stripped_token_str in (self.markers.tool_start, self.markers.tool_end):
+             self.logger.warning("Tool calling is not supported but a tool token was generated")
+             return None
+
+         return None
erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py
@@ -0,0 +1,17 @@
+ from abc import ABC, abstractmethod
+ from typing import Union
+
+ from transformers import PreTrainedTokenizer
+
+ from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
+
+
+ class StreamHandler(ABC):
+     def __init__(self, tokenizer: PreTrainedTokenizer):
+         self.tokenizer = tokenizer
+         self.special_tokens = tokenizer.all_special_tokens
+         self.eos_token = self.tokenizer.eos_token
+
+     @abstractmethod
+     def step(self, token_str: str) -> Union[None, CHATStreamResult]:
+         pass
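
The abstract base above is the extension point for model-specific stream post-processing. A hypothetical subclass, shown only to illustrate the required shape (not part of the package):

from typing import Union
from transformers import PreTrainedTokenizer
from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import StreamHandler

class PassthroughStreamHandler(StreamHandler):
    # Hypothetical handler: tag special tokens, pass everything else through untouched.
    def step(self, token_str: str) -> Union[None, CHATStreamResult]:
        if not token_str:
            return None
        if token_str.strip() in self.special_tokens:  # populated by StreamHandler.__init__
            return CHATStreamResult(text=token_str.strip(), marker="special", payload={})
        return CHATStreamResult(text=token_str, marker="text", payload={})
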
erictransformer/eric_tasks/chat_templates/__init__.py
@@ -0,0 +1 @@
+ from erictransformer.eric_tasks.chat_templates.convert import map_chat_roles
erictransformer/eric_tasks/chat_templates/convert.py
@@ -0,0 +1,67 @@
+ from erictransformer.exceptions import EricInputError
+
+
+ def map_chat_roles(messages: list, model_name: str, model_type: str):
+     role_map = get_role_map_for_model(model_name)
+     mapped_messages = []
+
+     if model_type == "smollm3":
+         mapped_messages.append({"role": "system", "content": "/think"})
+
+     for m in messages:
+         original_role = m.get("role")
+         mapped_role = role_map.get(original_role)
+
+         if mapped_role is None:
+             raise EricInputError(
+                 f"Unsupported role '{original_role}' for model '{model_name}'"
+             )
+
+         mapped_messages.append({"role": mapped_role, "content": m["content"]})
+
+     return mapped_messages
+
+
+ def get_role_map_for_model(model_name: str):
+     name = model_name.lower()
+     # Standard/default families
+     if any(
+         k in name
+         for k in (
+             "granite",
+             "smol",
+             "llama",
+             "mistral",
+             "mixtral",
+             "hermes",
+             "dialogpt",
+             "openchat",
+             "chatml",
+         )
+     ):
+         return {
+             "system": "system",
+             "user": "user",
+             "assistant": "assistant",
+             "tool": "tool",
+         }
+
+     if "gemma" in name:
+         return {"system": "user", "user": "user", "assistant": "model", "tool": "tool"}
+
+     # Families where 'system' should be treated as 'user'
+     if "falcon" in name or "vicuna" in name or "alpaca" in name:
+         return {
+             "system": "user",
+             "user": "user",
+             "assistant": "assistant",
+             "tool": "tool",
+         }
+
+     # Fallback
+     return {
+         "system": "system",
+         "user": "user",
+         "assistant": "assistant",
+         "tool": "tool",
+     }
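
For reference, a small usage sketch of the role-mapping helper above (the model names are illustrative examples of the substring matching, not a supported-model list):

from erictransformer.eric_tasks.chat_templates import map_chat_roles

messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi!"},
]

# Gemma-family names map "system" -> "user" (and "assistant" -> "model"):
map_chat_roles(messages, model_name="google/gemma-2-2b-it", model_type="gemma")
# -> [{'role': 'user', 'content': 'You are concise.'}, {'role': 'user', 'content': 'Hi!'}]

# model_type="smollm3" additionally prepends {"role": "system", "content": "/think"}:
map_chat_roles(messages, model_name="HuggingFaceTB/SmolLM3-3B", model_type="smollm3")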