erictransformer 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Public API of the erictransformer package: re-export the task classes,
# their call/tokenization argument dataclasses, and the result types.
from erictransformer.eric_tasks import (
    CHATCallArgs,
    CHATResult,
    CHATStreamResult,
    CHATTokArgs,
    GENCallArgs,
    GENResult,
    GENTokArgs,
    EricChat,
    EricGeneration,
    EricTextClassification,
    EricTextToText,
    TCCallArgs,
    TCResult,
    TCTokArgs,
    TTCallArgs,
    TTResult,
    TTTokArgs
)

from erictransformer.args import EricTrainArgs, EricEvalArgs
from erictransformer.eric_transformer import EricTransformer
from erictransformer.loops import EvalResult, TrainResult

# Only optional (MLX-gated) names are appended below; everything imported
# above is available as a module attribute regardless.
__all__ = []

# MLX support is optional: probe for the mlx_lm package and record the result.
try:
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from mlx_lm.utils import save_model
    mlx_enabled = True
except ImportError as err:
    mlx_enabled = False

if mlx_enabled:
    # mlx_lm is importable; the MLX chat task itself may still fail to load,
    # in which case it is silently skipped (EricChatMLX stays unavailable).
    try:
        from .eric_tasks.eric_chat_mlx import EricChatMLX
    except Exception:
        # Optional: log or ignore
        pass
    else:
        __all__.append("EricChatMLX")

name = "erictransformer"
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
@dataclass(kw_only=True)
class EricTrainArgs:
    """Configuration for a training run: learning schedule, periodic actions,
    and output location. All fields are keyword-only."""

    # Learning parameters
    lr: float = 2e-5  # learning rate
    bs: int = 1  # training batch size
    eval_bs: int = 0  # when 0 uses bs
    epochs: int = 1  # number of passes over the training data
    gas: int = 1  # presumably gradient accumulation steps — TODO confirm against train loop
    optim: str = "adamw"  # options adamw and sgd
    lr_sched: str = "constant"  # other option: warmup_then_decay

    # Action steps (all counted in optimizer/training steps — confirm unit in train loop)
    eval_steps: int = 256  # if 0 no evaluating will be done
    log_steps: int = 256  # if 0 no logging will be done
    checkpoint_steps: int = -1  # if -1 no checkpointing will be done
    save_best: bool = False  # saves the model with the lowest eval loss

    # Misc
    out_dir: str = "eric_transformer/"  # directory where run artifacts are written
    run_name: str = ""  # optional run identifier; empty string means unnamed
    seed: int = 42  # RNG seed for reproducibility
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(kw_only=True)
class TokArgs:
    """Base options for dataset tokenization; task-specific subclasses add
    length limits (see CHATTokArgs, GENTokArgs, etc.)."""

    bs: int = 1024  # tokenization batch size
    max_cases: int = -1  # -1 presumably means no cap on examples — TODO confirm
    shards: int = 1  # number of output shards to write
    procs: int = -1  # worker processes; -1 presumably means auto-detect — confirm
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(kw_only=True)
class EricEvalArgs:
    """Configuration for a standalone evaluation run."""

    bs: int = 1  # evaluation batch size
    out_dir: str = "eric_transformer/"  # directory where eval artifacts are written
    run_name: str = ""  # optional run identifier; empty string means unnamed
    seed: int = 42  # RNG seed for reproducibility
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(kw_only=True)
class CallArgs:
    """Marker base class for per-call (inference) argument dataclasses;
    task modules subclass this (e.g. CHATCallArgs, GENCallArgs)."""
    pass
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(kw_only=True)
class CallResult:
    """Marker base class for per-call result dataclasses returned by tasks."""
    pass
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Aggregate the eric_tasks subpackage API: argument dataclasses, task
# classes, streaming results, and call results.
from erictransformer.eric_tasks.args import (
    CHATCallArgs,
    CHATTokArgs,
    GENCallArgs,
    GENTokArgs,
    TCCallArgs,
    TCTokArgs,
    TTCallArgs,
    TTTokArgs,
)
from erictransformer.eric_tasks.chat_stream_handlers import CHATStreamResult
from erictransformer.eric_tasks.eric_chat import EricChat
from erictransformer.eric_tasks.eric_generation import EricGeneration
from erictransformer.eric_tasks.eric_text_classification import (
    EricTextClassification,
)
from erictransformer.eric_tasks.eric_text_to_text import (
    EricTextToText,
    TTStreamResult,
)
from erictransformer.eric_tasks.results import (
    CHATResult,
    GENResult,
    TCResult,
    TTResult
)

# Only the optional MLX-backed task is appended below.
__all__ = []


# MLX support is optional: probe for the mlx_lm package and record the result.
try:
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from mlx_lm.utils import save_model
    mlx_enabled = True
except ImportError as err:
    mlx_enabled = False


if mlx_enabled:
    # mlx_lm is importable; the MLX chat task itself may still fail to load,
    # in which case it is silently skipped (EricChatMLX stays unavailable).
    try:
        from .eric_chat_mlx import EricChatMLX
    except Exception:
        # Optional: log or ignore
        pass
    else:
        __all__.append("EricChatMLX")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from erictransformer.eric_tasks.args.eric_chat_args import (
|
|
2
|
+
CHATCallArgs,
|
|
3
|
+
CHATTokArgs,
|
|
4
|
+
)
|
|
5
|
+
from erictransformer.eric_tasks.args.eric_generation_args import (
|
|
6
|
+
GENCallArgs,
|
|
7
|
+
GENTokArgs,
|
|
8
|
+
)
|
|
9
|
+
from erictransformer.eric_tasks.args.eric_text_classification_args import (
|
|
10
|
+
TCCallArgs,
|
|
11
|
+
TCTokArgs,
|
|
12
|
+
)
|
|
13
|
+
from erictransformer.eric_tasks.args.eric_text_to_text_args import (
|
|
14
|
+
TTCallArgs,
|
|
15
|
+
TTTokArgs,
|
|
16
|
+
)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from ericsearch import SearchCallArgs
|
|
3
|
+
|
|
4
|
+
from erictransformer.args import CallArgs, TokArgs
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(kw_only=True)
class CHATCallArgs(CallArgs):
    """Sampling and retrieval options for a single chat inference call."""

    min_len: int = 1  # minimum generation length — unit (tokens) presumed, confirm
    max_len: int = 4096  # maximum generation length
    temp: float = 0.8  # sampling temperature
    top_k: int = 32  # top-k sampling cutoff
    top_p: float = 0.6  # nucleus sampling cutoff

    # dataset args
    # Retrieval (ericsearch) options used for RAG-style lookups — confirm usage in EricChat.
    search_args: SearchCallArgs = field(default_factory=SearchCallArgs)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(kw_only=True)
class CHATTokArgs(TokArgs):
    """Tokenization options for the chat task."""

    max_len: int = -1  # -1 presumably disables truncation — TODO confirm
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from erictransformer.args import CallArgs, TokArgs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass(kw_only=True)
class GENTokArgs(TokArgs):
    """Tokenization options for the text-generation task."""

    max_len: int = 1024  # maximum tokenized sequence length
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(kw_only=True)
class GENCallArgs(CallArgs):
    """Sampling options for a single text-generation call."""

    min_len: int = 1  # minimum generation length — unit (tokens) presumed, confirm
    max_len: int = 1024  # maximum generation length
    temp: float = 0.8  # sampling temperature
    top_k: int = 32  # top-k sampling cutoff
    top_p: float = 0.6  # nucleus sampling cutoff
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from erictransformer.args import CallArgs, TokArgs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass(kw_only=True)
class TTCallArgs(CallArgs):
    """Sampling options for a single text-to-text call."""

    min_len: int = 1  # minimum generation length — unit (tokens) presumed, confirm
    max_len: int = 1024  # maximum generation length
    temp: float = 0.8  # sampling temperature
    top_k: int = 32  # top-k sampling cutoff
    top_p: float = 0.6  # nucleus sampling cutoff
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(kw_only=True)
class TTTokArgs(TokArgs):
    """Tokenization options for the text-to-text task, with separate
    input/output length limits."""

    max_in_len: int = -1  # input length cap; -1 presumably means unlimited — confirm
    max_out_len: int = -1  # output length cap; -1 presumably means unlimited — confirm
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
|
|
2
|
+
from erictransformer.eric_tasks.chat_stream_handlers.default import (
|
|
3
|
+
DefaultStreamHandler,
|
|
4
|
+
)
|
|
5
|
+
from erictransformer.eric_tasks.chat_stream_handlers.gpt_oss import GPTOSSSMHandler
|
|
6
|
+
from erictransformer.eric_tasks.chat_stream_handlers.smol import SmolStreamHandler
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
@dataclass(kw_only=True)
class CHATStreamResult:
    """One unit of streamed chat output, tagged so the consumer knows how to
    render it (plain text vs. special/thinking markers)."""

    text: str  # what user sees
    marker: str  # the marker e.g text, special, tool (also think_start/think_end/thinking — see handlers)
    payload: Optional[dict]  # extra structured data; required field, handlers pass {} today
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(kw_only=True)
class MarkerStrings:
    """Base class for model-family marker-token string sets; stream handlers
    subclass this with their concrete special-token strings."""
    pass
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from transformers import PreTrainedTokenizer
|
|
4
|
+
|
|
5
|
+
from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
|
|
6
|
+
from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import (
|
|
7
|
+
StreamHandler,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DefaultStreamHandler(StreamHandler):
    """Fallback stream handler: surfaces every non-empty token verbatim as
    plain text, with no special-token or thinking-channel interpretation."""

    def __init__(self, tokenizer: PreTrainedTokenizer):
        super().__init__(tokenizer=tokenizer)

    def step(self, token_str: str) -> Union[None, CHATStreamResult]:
        """Wrap *token_str* in a plain-text stream result; empty strings are dropped."""
        if token_str:
            return CHATStreamResult(text=token_str, marker="text", payload={})
        return None
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Union
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from transformers import PreTrainedTokenizer
|
|
6
|
+
|
|
7
|
+
from erictransformer.eric_tasks.chat_stream_handlers.args import (
|
|
8
|
+
CHATStreamResult,
|
|
9
|
+
MarkerStrings,
|
|
10
|
+
)
|
|
11
|
+
from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import (
|
|
12
|
+
StreamHandler,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(kw_only=True)
class GPTOSSMarkerStrings(MarkerStrings):
    """Special-token strings of the GPT-OSS ("harmony") chat format, as
    matched by GPTOSSSMHandler."""

    start: str = "<|start|>"  # opens a message header
    end: str = "<|end|>"  # closes a message
    message: str = "<|message|>"  # separates header from content
    channel: str = "<|channel|>"  # the next token names the channel (analysis/final/...)
    constrain: str = "<|constrain|>"  # content-type constraint marker
    return_token: str = "<|return|>"  # end of the streamed response
    call: str = "<|call|>"  # tool-call marker (unused by the handler below)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GPTOSSSMHandler(StreamHandler):
    """Stream handler for the GPT-OSS harmony token format.

    Tracks a small per-message state machine (current channel, header flags)
    and maps each decoded token to a CHATStreamResult whose marker tells the
    consumer how to render it: "special" for protocol tokens, "thinking"
    for analysis-channel text, "text" for final-channel text, and
    "think_start"/"think_end" around the analysis channel.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, logger: logging.Logger):
        self.markers = GPTOSSMarkerStrings()
        # Per-message parsing state; mirrors _reset_state below.
        self.current_channel = ""
        self.in_thinking = False
        self.in_tool = False  # set nowhere else in this class — appears unused
        self.change_channel = False
        self.just_received_message = False  # set nowhere else — appears unused
        self.start_just_happened = False
        self.tool_strings = []
        self.logger = logger
        super().__init__(tokenizer=tokenizer)

    def _reset_state(self):
        """Clear all per-message parsing state (new message / end of stream)."""
        self.current_channel = ""
        self.in_thinking = False
        self.in_tool = False
        self.change_channel = False
        self.just_received_message = False
        self.start_just_happened = False
        self.tool_strings = []

    def step(self, token_str: str) -> Union[None, CHATStreamResult]:
        """Classify one decoded token and return a stream result, or None
        for empty tokens / unrecognized state."""
        if not token_str:
            return None

        stripped_token_str = token_str.strip()

        # NOTE(review): unconditionally clearing in_thinking on every token
        # looks suspicious — the flag can never stay set across tokens; confirm intent.
        self.in_thinking = False
        #### SPECIAL TOKENS ####

        # Case: <|start|> — a new message header begins; reset and remember it.
        if stripped_token_str == self.markers.start:
            self._reset_state()
            self.start_just_happened = True
            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )

        # Case: <|channel|> — the NEXT token names the channel.
        if stripped_token_str == self.markers.channel:
            self.change_channel = True
            # next round
            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )

        # Case: <|message|> — header/content separator; passed through as special.
        if stripped_token_str == self.markers.message:

            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )

        # Case: <|return|>
        if stripped_token_str == self.markers.return_token:
            # do nothing. Streaming is over.
            self._reset_state()
            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )

        # Case: <|constrain|>
        if stripped_token_str == self.markers.constrain:
            # do nothing. we always assume json
            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )

        # Case: <|end|> — message over; emit think_end if we were in analysis.
        if stripped_token_str == self.markers.end:
            temp_current_channel = self.current_channel
            self._reset_state()
            if temp_current_channel == "analysis":
                return CHATStreamResult(
                    text=stripped_token_str, marker="think_end", payload={}
                )
            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )

        #### NON SPECIAL TOKENS ####

        if self.start_just_happened:
            # we don't do anything with this header token
            self.start_just_happened = False

            # Role name directly after <|start|> is surfaced as a special token.
            if stripped_token_str in {"assistant", "user", "system"}:
                return CHATStreamResult(text=token_str, marker="special", payload={})

        if self.change_channel:
            # This token names the channel announced by the preceding <|channel|>.
            self.change_channel = False

            # NOTE(review): this compares the PREVIOUS channel, not the incoming
            # token — "commentary" as the new channel would not hit this branch;
            # confirm whether stripped_token_str was intended here.
            if self.current_channel == "commentary":
                self.logger.warning("The 'commentary' channel is not supported. Falling back to 'analysis'. ")
                self.current_channel = "analysis"

            if stripped_token_str not in ("analysis", "final"):
                # fall back to final
                self.logger.warning(f"Unexpected channel '{stripped_token_str}'. Falling back to 'final'.")
                self.current_channel = "final"
            else:
                self.current_channel = stripped_token_str

            # Entering analysis announces the start of a "thinking" span.
            if stripped_token_str == "analysis":
                return CHATStreamResult(
                    text=stripped_token_str, marker="think_start", payload={}
                )

            return CHATStreamResult(
                text=stripped_token_str, marker="special", payload={}
            )
        # Ordinary content tokens: route by the active channel.
        if self.current_channel == "analysis":
            # just return the text
            return CHATStreamResult(text=token_str, marker="thinking", payload={})

        if self.current_channel == "final":
            # just return the text
            return CHATStreamResult(text=token_str, marker="text", payload={})

        # Unrecognized state: drop the token (implicit None) and start fresh.
        self._reset_state()
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Union
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from transformers import PreTrainedTokenizer
|
|
6
|
+
|
|
7
|
+
from erictransformer.eric_tasks.chat_stream_handlers.args import (
|
|
8
|
+
CHATStreamResult,
|
|
9
|
+
MarkerStrings,
|
|
10
|
+
)
|
|
11
|
+
from erictransformer.eric_tasks.chat_stream_handlers.stream_handler import (
|
|
12
|
+
StreamHandler,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(kw_only=True)
class SMOLMarkerStrings(MarkerStrings):
    """Marker-token strings for the SmolLM chat format; values are supplied
    by SmolStreamHandler.__init__ (no defaults here)."""

    think_start: str  # opens a thinking span
    think_end: str  # closes a thinking span
    tool_start: str  # tool-call open tag (not supported by the handler)
    tool_end: str  # tool-call close tag (not supported by the handler)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SmolStreamHandler(StreamHandler):
    """Stream handler for SmolLM-style chat output.

    Recognizes <think>/</think> spans (emitted with thinking markers),
    passes tokenizer special tokens through as "special", strips a trailing
    EOS from ordinary text, and drops unsupported tool-call tags with a
    warning.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, logger: logging.Logger):
        self.markers: SMOLMarkerStrings = SMOLMarkerStrings(
            think_start="<think>",
            think_end="</think>",
            tool_start="<tool_call>",  # not supported
            tool_end="</tool_call>",  # not supported
        )

        # Flat tuple of all marker strings for a fast membership test in step().
        self.marker_key_tuple = (
            self.markers.think_start,
            self.markers.think_end,
            self.markers.tool_start,
            self.markers.tool_end,
        )

        # True while we are between <think> and </think>.
        self.in_thinking = False

        self.logger = logger

        super().__init__(tokenizer=tokenizer)

    def step(self, token_str: str) -> Union[None, CHATStreamResult]:
        """Classify one decoded token; returns None for empty tokens and
        for (unsupported) tool-call markers."""
        if not token_str:
            return None

        stripped_token_str = token_str.strip()

        if stripped_token_str not in self.marker_key_tuple:
            # Inside a thinking span everything is "thinking" — note this is
            # checked BEFORE the special-token test, so special tokens inside
            # <think> are also emitted as thinking text.
            if self.in_thinking:
                return CHATStreamResult(text=token_str, marker="thinking", payload={})

            if stripped_token_str in self.special_tokens:
                return CHATStreamResult(
                    text=stripped_token_str, marker="special", payload={})

            else:
                # Ordinary text; drop a trailing EOS glued onto the token.
                if token_str.endswith(self.eos_token):
                    token_str = token_str[: -len(self.eos_token)]
                return CHATStreamResult(text=token_str, marker="text", payload={})

        elif stripped_token_str == self.markers.think_start:
            self.in_thinking = True
            return CHATStreamResult(
                text=stripped_token_str, marker="think_start", payload={}
            )

        elif stripped_token_str == self.markers.think_end:
            self.in_thinking = False
            return CHATStreamResult(
                text=stripped_token_str, marker="think_end", payload={}
            )

        elif stripped_token_str in (self.markers.tool_start, self.markers.tool_end):
            self.logger.warning("Tool calling is not supported but a tool token was generated")
            return None

        return None
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from transformers import PreTrainedTokenizer
|
|
5
|
+
|
|
6
|
+
from erictransformer.eric_tasks.chat_stream_handlers.args import CHATStreamResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class StreamHandler(ABC):
    """Abstract base for chat token-stream post-processors.

    Subclasses implement step(), which receives one decoded token string at
    a time and returns a CHATStreamResult (or None to drop the token).
    """

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        # Cache tokenizer attributes that subclasses consult per token.
        self.special_tokens = tokenizer.all_special_tokens
        self.eos_token = self.tokenizer.eos_token

    @abstractmethod
    def step(self, token_str: str) -> Union[None, CHATStreamResult]:
        """Consume one decoded token string; return a stream result or None."""
        pass
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from erictransformer.eric_tasks.chat_templates.convert import map_chat_roles
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from erictransformer.exceptions import EricInputError
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def map_chat_roles(messages: list, model_name: str, model_type: str):
    """Rewrite the roles in a chat transcript to the names a model family expects.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.
        model_name: model identifier used to pick the role mapping.
        model_type: when "smollm3", a {"role": "system", "content": "/think"}
            message is prepended to enable thinking mode.

    Returns:
        A new message list with mapped roles; contents are unchanged.

    Raises:
        EricInputError: if a message carries a role the model does not support.
    """
    roles = get_role_map_for_model(model_name)

    # SmolLM3 needs an explicit system toggle to enable thinking.
    converted = (
        [{"role": "system", "content": "/think"}]
        if model_type == "smollm3"
        else []
    )

    for message in messages:
        source_role = message.get("role")
        target_role = roles.get(source_role)

        if target_role is None:
            raise EricInputError(
                f"Unsupported role '{source_role}' for model '{model_name}'"
            )

        converted.append({"role": target_role, "content": message["content"]})

    return converted
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_role_map_for_model(model_name: str):
    """Return the role-name mapping for the model family inferred from *model_name*.

    Matching is a case-insensitive substring test. Families that accept the
    standard role names get an identity mapping; Gemma renames "assistant" to
    "model"; Falcon/Vicuna/Alpaca fold "system" into "user". Unknown models
    fall back to the identity mapping.
    """
    lowered = model_name.lower()

    identity = {
        "system": "system",
        "user": "user",
        "assistant": "assistant",
        "tool": "tool",
    }

    # Standard/default families — use roles as-is.
    standard_families = (
        "granite",
        "smol",
        "llama",
        "mistral",
        "mixtral",
        "hermes",
        "dialogpt",
        "openchat",
        "chatml",
    )
    if any(family in lowered for family in standard_families):
        return identity

    # Gemma has no assistant role; it calls it "model".
    if "gemma" in lowered:
        return {"system": "user", "user": "user", "assistant": "model", "tool": "tool"}

    # Families where 'system' should be treated as 'user'.
    if any(family in lowered for family in ("falcon", "vicuna", "alpaca")):
        return {
            "system": "user",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
        }

    # Fallback: assume standard role names.
    return identity
|