llmflowstack-1.0.0-py3-none-any.whl
- llmflowstack/__init__.py +19 -0
- llmflowstack/base/__init__.py +0 -0
- llmflowstack/base/base.py +527 -0
- llmflowstack/callbacks/__init__.py +0 -0
- llmflowstack/callbacks/log_collector.py +21 -0
- llmflowstack/callbacks/stop_on_token.py +16 -0
- llmflowstack/models/GPT_OSS.py +265 -0
- llmflowstack/models/Gemma.py +247 -0
- llmflowstack/models/LLaMA3.py +213 -0
- llmflowstack/models/__init__.py +9 -0
- llmflowstack/rag/__init__.py +5 -0
- llmflowstack/rag/pipeline.py +114 -0
- llmflowstack/schemas/__init__.py +9 -0
- llmflowstack/schemas/params.py +39 -0
- llmflowstack/utils/__init__.py +11 -0
- llmflowstack/utils/evaluation_methods.py +92 -0
- llmflowstack/utils/exceptions.py +2 -0
- llmflowstack/utils/generation_utils.py +30 -0
- llmflowstack-1.0.0.dist-info/METADATA +229 -0
- llmflowstack-1.0.0.dist-info/RECORD +22 -0
- llmflowstack-1.0.0.dist-info/WHEEL +4 -0
- llmflowstack-1.0.0.dist-info/licenses/LICENSE +21 -0
llmflowstack/models/GPT_OSS.py
@@ -0,0 +1,265 @@
+import textwrap
+import threading
+from functools import partial
+from time import time
+from typing import Iterator, Literal, TypedDict, cast
+
+import torch
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
+from transformers import (AutoTokenizer, StoppingCriteriaList,
+                          TextIteratorStreamer)
+from transformers.models.gpt_oss import GptOssForCausalLM
+from transformers.utils.quantization_config import Mxfp4Config
+
+from llmflowstack.base.base import BaseModel
+from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.schemas.params import GenerationParams
+from llmflowstack.utils.exceptions import MissingEssentialProp
+from llmflowstack.utils.generation_utils import create_generation_params
+
+
+class GPTOSSInput(TypedDict):
+    input_text: str
+    system_message: str | None
+    developer_message: str | None
+    expected_answer: str | None
+    reasoning_message: str | None
+    reasoning_level: Literal["Low", "Medium", "High"] | None
+
+class GPT_OSS(BaseModel):
+    model: GptOssForCausalLM | None = None
+    reasoning_level: Literal["Low", "Medium", "High"] = "Low"
+    question_fields = ["input_text", "developer_message", "system_message"]
+    answer_fields = ["expected_answer", "reasoning_message"]
+
+    def _set_generation_stopping_tokens(
+        self,
+        tokens: list[int]
+    ) -> None:
+        if not self.tokenizer:
+            self._log("Could not set stop tokens - generation may not work...", "WARNING")
+            return
+        # Harmony defines the tokens that end an assistant action (e.g. <|return|>).
+        encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+        particular_tokens = encoding.stop_tokens_for_assistant_actions()
+        self.stop_token_ids = particular_tokens + tokens
+
+    def _load_model(
+        self,
+        checkpoint: str,
+        quantization: Literal["8bit", "4bit"] | bool | None = False
+    ) -> None:
+        if quantization:
+            self.model_is_quantized = True
+            # Keep the checkpoint's native MXFP4 quantization.
+            quantization_config = Mxfp4Config(dequantize=False)
+        else:
+            # Dequantize MXFP4 weights to a higher-precision dtype on load.
+            quantization_config = Mxfp4Config(dequantize=True)
+
+        try:
+            self.model = GptOssForCausalLM.from_pretrained(
+                checkpoint,
+                quantization_config=quantization_config,
+                dtype="auto",
+                device_map="auto",
+                attn_implementation="eager",
+            )
+        except Exception:
+            self._log("Error trying to load the model. Defaulting to load without quantization...", "WARNING")
+            self.model = GptOssForCausalLM.from_pretrained(
+                checkpoint,
+                dtype="auto",
+                device_map="auto",
+                attn_implementation="eager"
+            )
+
+    def _build_input(
+        self,
+        input_text: str,
+        expected_answer: str | None = None,
+        system_message: str | None = None,
+        reasoning_level: Literal["Low", "Medium", "High"] | None = None,
+        reasoning_message: str | None = None,
+        developer_message: str | None = None
+    ) -> str:
+        if not self.tokenizer:
+            raise MissingEssentialProp("Could not find tokenizer.")
+
+        reasoning = reasoning_level
+        if reasoning is None:
+            reasoning = self.reasoning_level
+
+        # Render the Harmony chat format by hand: system turn, optional developer
+        # turn, user turn and (for training-style inputs) prefilled assistant turns.
+        system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\nReasoning: {reasoning}\n\n{system_message or ''}# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
+
+        developer_text = ""
+        if developer_message:
+            developer_text = f"<|start|>developer<|message|># Instructions\n\n{developer_message}<|end|>"
+
+        assistant_text = ""
+        if reasoning_message:
+            assistant_text += f"<|start|>assistant<|channel|>analysis<|message|>{reasoning_message}<|end|>"
+
+        if expected_answer:
+            assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"
+
+        return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{input_text}<|end|>{assistant_text}""")
+
+    def build_input(
+        self,
+        input_text: str,
+        system_message: str | None = None,
+        developer_message: str | None = None,
+        expected_answer: str | None = None,
+        reasoning_message: str | None = None,
+        reasoning_level: Literal["Low", "Medium", "High"] | None = None
+    ) -> GPTOSSInput:
+        if not self.tokenizer:
+            raise MissingEssentialProp("Could not find tokenizer.")
+
+        return {
+            "input_text": input_text,
+            "developer_message": developer_message,
+            "system_message": system_message,
+            "reasoning_level": reasoning_level,
+            "expected_answer": expected_answer,
+            "reasoning_message": reasoning_message
+        }
+
+    def set_reasoning_level(
+        self,
+        level: Literal["Low", "Medium", "High"]
+    ) -> None:
+        self.reasoning_level = level
+
+    def generate(
+        self,
+        input: GPTOSSInput | str,
+        params: GenerationParams | None = None
+    ) -> str | None:
+        if self.model is None or self.tokenizer is None:
+            self._log("Model or Tokenizer missing", "WARNING")
+            return None
+
+        self._log("Processing received input...")
+
+        if params is None:
+            params = GenerationParams(max_new_tokens=32768)
+        elif params.max_new_tokens is None:
+            params.max_new_tokens = 32768
+
+        generation_params = create_generation_params(params)
+        self.model.generation_config = generation_params
+
+        if isinstance(input, str):
+            model_input = self._build_input(
+                input_text=input
+            )
+        else:
+            model_input = self._build_input(
+                input_text=input["input_text"],
+                developer_message=input.get("developer_message"),
+                system_message=input.get("system_message"),
+                reasoning_level=input.get("reasoning_level")
+            )
+
+        tokenized_input = self._tokenize(model_input)
+
+        input_ids, attention_mask = tokenized_input
+
+        self.model.eval()
+        self.model.gradient_checkpointing_disable()
+        start = time()
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=True,
+                eos_token_id=None,
+                stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
+            )
+
+        answer = self.tokenizer.decode(outputs[0])
+
+        end = time()
+        total_time = end - start
+
+        self._log(f"Response generated in {total_time:.4f} seconds")
+
+        # Slice out the last <|message|> span, i.e. the final-channel answer.
+        start = answer.rfind("<|message|>")
+        if start == -1:
+            return ""
+
+        start += len("<|message|>")
+
+        end = answer.find("<|return|>", start)
+        if end == -1:
+            end = len(answer)
+
+        return answer[start:end].strip()
+
+    def generate_stream(
+        self,
+        input: GPTOSSInput | str,
+        params: GenerationParams | None = None
+    ) -> Iterator[str]:
+        if self.model is None or self.tokenizer is None:
+            self._log("Model or Tokenizer missing", "WARNING")
+            return
+
+        if params is None:
+            params = GenerationParams(max_new_tokens=32768)
+        elif params.max_new_tokens is None:
+            params.max_new_tokens = 32768
+
+        generation_params = create_generation_params(params)
+        self.model.generation_config = generation_params
+
+        if isinstance(input, str):
+            model_input = self._build_input(
+                input_text=input
+            )
+        else:
+            model_input = self._build_input(
+                input_text=input["input_text"],
+                developer_message=input.get("developer_message"),
+                system_message=input.get("system_message"),
+                reasoning_level=input.get("reasoning_level")
+            )
+
+        tokenized_input = self._tokenize(model_input)
+        input_ids, attention_mask = tokenized_input
+
+        streamer = TextIteratorStreamer(
+            cast(AutoTokenizer, self.tokenizer),
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        generate_fn = partial(
+            self.model.generate,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            use_cache=True,
+            eos_token_id=None,
+            streamer=streamer,
+            stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
+        )
+
+        # Generation runs in a background thread; the streamer yields decoded text.
+        thread = threading.Thread(target=generate_fn)
+        thread.start()
+
+        done_thinking = False
+        buffer = ""
+
+        for new_text in streamer:
+            buffer += new_text
+
+            # Suppress the analysis channel until the "final" channel marker shows up
+            # (special tokens are skipped, so the channel name arrives as plain text).
+            if not done_thinking and "final" in buffer:
+                done_thinking = True
+                buffer = buffer.split("final", 1)[1]
+
+            if done_thinking:
+                yield buffer
+                buffer = ""
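A usage sketch for the class above (illustrative, not part of the package diff): it assumes BaseModel in llmflowstack/base/base.py, which is not shown here, provides a constructor that calls _load_model and sets up the tokenizer and stop tokens, and that the top-level llmflowstack/__init__.py re-exports GPT_OSS; the checkpoint name is a placeholder.

    from llmflowstack import GPT_OSS  # assumed re-export
    from llmflowstack.schemas.params import GenerationParams

    model = GPT_OSS("openai/gpt-oss-20b")  # hypothetical constructor signature
    payload = model.build_input(
        input_text="What is 2 + 2?",
        reasoning_level="Low",
    )
    answer = model.generate(payload, params=GenerationParams(max_new_tokens=1024))
    print(answer)

Under the hood, _build_input renders this payload into a flat Harmony-format string: a system turn ("You are ChatGPT, ... Reasoning: Low ... # Valid channels: analysis, commentary, final."), an optional developer turn, the user turn, and, for training-style inputs, prefilled assistant analysis/final turns.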
llmflowstack/models/Gemma.py
@@ -0,0 +1,247 @@
+import textwrap
+import threading
+from functools import partial
+from time import time
+from typing import Iterator, Literal, TypedDict, cast
+
+import torch
+from transformers import (AutoTokenizer, StoppingCriteriaList,
+                          TextIteratorStreamer)
+from transformers.models.gemma3 import Gemma3ForCausalLM
+from transformers.utils.quantization_config import BitsAndBytesConfig
+
+from llmflowstack.base.base import BaseModel
+from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.schemas.params import GenerationParams
+from llmflowstack.utils.exceptions import MissingEssentialProp
+from llmflowstack.utils.generation_utils import create_generation_params
+
+
+class GemmaInput(TypedDict):
+    input_text: str
+    expected_answer: str | None
+    system_message: str | None
+
+class Gemma(BaseModel):
+    model: Gemma3ForCausalLM | None = None
+    can_think = False
+    question_fields = ["input_text", "system_message"]
+    answer_fields = ["expected_answer"]
+
+    def _set_generation_stopping_tokens(
+        self,
+        tokens: list[int]
+    ) -> None:
+        if not self.tokenizer:
+            self._log("Could not set stop tokens - generation may not work...", "WARNING")
+            return
+        # Gemma ends each turn with <end_of_turn>.
+        particular_tokens = self.tokenizer.encode("<end_of_turn>")
+        self.stop_token_ids = tokens + particular_tokens
+
+    def _load_model(
+        self,
+        checkpoint: str,
+        quantization: Literal["8bit", "4bit"] | bool | None = None
+    ) -> None:
+        quantization_config = None
+        if quantization == "4bit":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True
+            )
+        if quantization == "8bit":
+            quantization_config = BitsAndBytesConfig(
+                load_in_8bit=True
+            )
+
+        self.model = Gemma3ForCausalLM.from_pretrained(
+            checkpoint,
+            quantization_config=quantization_config,
+            dtype="auto",
+            device_map="auto",
+            attn_implementation="eager"
+        )
+
+    def _build_input(
+        self,
+        input_text: str,
+        expected_answer: str | None = None,
+        system_message: str | None = None
+    ) -> str:
+        if not self.tokenizer:
+            raise MissingEssentialProp("Could not find tokenizer.")
+
+        if not system_message:
+            system_message = ""
+        if self.can_think:
+            system_message = f"Think silently if needed. {system_message}"
+
+        if system_message:
+            system_message = f"{system_message}\n"
+
+        # For training-style inputs the expected answer closes the model turn.
+        answer = f"{expected_answer}<end_of_turn>" if expected_answer else ""
+
+        return textwrap.dedent(
+            f"<start_of_turn>user\n"
+            f"{system_message}{input_text}<end_of_turn>\n"
+            f"<start_of_turn>model\n"
+            f"{answer}"
+        )
+
+    def build_input(
+        self,
+        input_text: str,
+        expected_answer: str | None = None,
+        system_message: str | None = None
+    ) -> GemmaInput:
+        if not self.tokenizer:
+            raise MissingEssentialProp("Could not find tokenizer.")
+
+        return {
+            "input_text": input_text,
+            "expected_answer": expected_answer,
+            "system_message": system_message
+        }
+
+    def set_can_think(self, value: bool) -> None:
+        self.can_think = value
+
+    def generate(
+        self,
+        input: GemmaInput | str,
+        params: GenerationParams | None = None,
+    ) -> str | None:
+        if self.model is None or self.tokenizer is None:
+            self._log("Model or Tokenizer missing", "WARNING")
+            return None
+
+        self._log("Processing received input...")
+
+        if params is None:
+            params = GenerationParams(max_new_tokens=32768)
+        elif params.max_new_tokens is None:
+            params.max_new_tokens = 32768
+
+        generation_params = create_generation_params(params)
+        self.model.generation_config = generation_params
+
+        if isinstance(input, str):
+            model_input = self._build_input(
+                input_text=input
+            )
+        else:
+            model_input = self._build_input(
+                input_text=input["input_text"],
+                system_message=input.get("system_message")
+            )
+
+        tokenized_input = self._tokenize(model_input)
+
+        input_ids, attention_mask = tokenized_input
+
+        self.model.eval()
+        self.model.gradient_checkpointing_disable()
+        start = time()
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=True,
+                eos_token_id=None,
+                stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
+            )
+
+        answer = self.tokenizer.decode(outputs[0])
+
+        end = time()
+        total_time = end - start
+
+        self._log(f"Response generated in {total_time:.4f} seconds")
+
+        # The answer starts after the end-of-thinking marker (<unused95>) when the
+        # model thought first, otherwise right after the model turn marker.
+        start = answer.rfind("<unused95>")
+        if start == -1:
+            start = answer.rfind("<start_of_turn>model")
+            start = start + len("<start_of_turn>model")
+        else:
+            start = start + len("<unused95>")
+
+        end = answer.find("<end_of_turn>", start)
+        if end == -1:
+            end = len(answer)
+
+        return answer[start:end].strip().replace("<eos>", "")
+
+    def generate_stream(
+        self,
+        input: GemmaInput | str,
+        params: GenerationParams | None = None
+    ) -> Iterator[str]:
+        if self.model is None or self.tokenizer is None:
+            self._log("Model or Tokenizer missing", "WARNING")
+            return
+
+        if params is None:
+            params = GenerationParams(max_new_tokens=32768)
+        elif params.max_new_tokens is None:
+            params.max_new_tokens = 32768
+
+        generation_params = create_generation_params(params)
+        self.model.generation_config = generation_params
+
+        if isinstance(input, str):
+            model_input = self._build_input(
+                input_text=input
+            )
+        else:
+            model_input = self._build_input(
+                input_text=input["input_text"],
+                system_message=input.get("system_message")
+            )
+
+        tokenized_input = self._tokenize(model_input)
+        input_ids, attention_mask = tokenized_input
+
+        streamer = TextIteratorStreamer(
+            cast(AutoTokenizer, self.tokenizer),
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        generate_fn = partial(
+            self.model.generate,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            use_cache=True,
+            eos_token_id=None,
+            streamer=streamer,
+            stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
+        )
+
+        thread = threading.Thread(target=generate_fn)
+        thread.start()
+
+        buffer = ""
+        is_thinking = None
+
+        for new_text in streamer:
+            buffer += new_text
+
+            if is_thinking is None:
+                # Undecided yet: a longer opening means the model answered directly;
+                # "thought"/<unused94> early in the stream marks a thinking span.
+                if len(buffer.split()) > 5:
+                    is_thinking = False
+                    continue
+
+                lower_buffer = buffer.lower()
+                if lower_buffer.find("thought") != -1 or lower_buffer.find("<unused94>") != -1:
+                    is_thinking = True
+                    continue
+            elif not is_thinking:
+                yield buffer
+                buffer = ""
+            else:
+                # <unused95> closes the thinking span; keep only the text after it.
+                if buffer.find("<unused95>") != -1:
+                    is_thinking = False
+                    buffer = buffer.split("<unused95>", 1)[1]
+
+        # Flush any remaining text once the stream ends.
+        if buffer and not is_thinking:
+            yield buffer
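A similar sketch for streaming with Gemma (again assuming BaseModel handles construction and loading, and that the class is re-exported from the package root; the checkpoint name is a placeholder):

    from llmflowstack import Gemma  # assumed re-export
    from llmflowstack.schemas.params import GenerationParams

    model = Gemma("google/gemma-3-4b-it")  # hypothetical constructor signature
    model.set_can_think(True)  # prepends "Think silently if needed." to the system message

    payload = model.build_input(input_text="Name three prime numbers.")
    for chunk in model.generate_stream(payload, params=GenerationParams(max_new_tokens=256)):
        print(chunk, end="", flush=True)

Note the asymmetry with GPT_OSS: here the thinking span is detected heuristically from the first few streamed words (or the <unused94> marker) rather than from an explicit channel name, and only text after <unused95> is yielded.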