llmflowstack 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,265 @@
1
+ import textwrap
2
+ import threading
3
+ from functools import partial
4
+ from time import time
5
+ from typing import Any, Generator, Iterator, Literal, TypedDict, cast
6
+
7
+ import torch
8
+ from openai_harmony import HarmonyEncodingName, load_harmony_encoding
9
+ from transformers import (AutoTokenizer, StoppingCriteriaList,
10
+ TextIteratorStreamer)
11
+ from transformers.models.gpt_oss import GptOssForCausalLM
12
+ from transformers.utils.quantization_config import Mxfp4Config
13
+
14
+ from llmflowstack.base.base import BaseModel
15
+ from llmflowstack.callbacks.stop_on_token import StopOnToken
16
+ from llmflowstack.schemas.params import GenerationParams
17
+ from llmflowstack.utils.exceptions import MissingEssentialProp
18
+ from llmflowstack.utils.generation_utils import create_generation_params
19
+
20
+
21
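+ # Prompt fields accepted by GPT_OSS: the user text plus optional system/developer
+ # messages, a reasoning level, and (apparently as fine-tuning targets) an expected
+ # answer and reasoning trace.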
+ class GPTOSSInput(TypedDict):
22
+ input_text: str
23
+ system_message: str | None
24
+ developer_message: str | None
25
+ expected_answer: str | None
26
+ reasoning_message: str | None
27
+ reasoning_level: Literal["Low", "Medium", "High"] | None
28
+
29
+ class GPT_OSS(BaseModel):
30
+ model: GptOssForCausalLM | None = None
31
+ reasoning_level: Literal["Low", "Medium", "High"] = "Low"
32
+ question_fields = ["input_text", "developer_message", "system_message"]
33
+ answer_fields = ["expected_answer", "reasoning_message"]
34
+
35
+ def _set_generation_stopping_tokens(
36
+ self,
37
+ tokens: list[int]
38
+ ) -> None:
39
+ if not self.tokenizer:
40
+ self._log("Could not set stop tokens - generation may not work...", "WARNING")
41
+ return None
42
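+ # Merge the harmony encoding's assistant stop tokens (e.g. <|return|>) with any
+ # caller-supplied token ids so generation halts at the end of an assistant turn.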
+ encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
43
+ particular_tokens = encoding.stop_tokens_for_assistant_actions()
44
+ self.stop_token_ids = particular_tokens + tokens
45
+
46
+ def _load_model(
47
+ self,
48
+ checkpoint: str,
49
+ quantization: Literal["8bit", "4bit"] | bool | None = False
50
+ ) -> None:
51
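+ # GPT-OSS checkpoints ship MXFP4-quantized weights; keep them quantized when
+ # quantization is requested, otherwise dequantize to a higher-precision dtype.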
+ if quantization:
52
+ self.model_is_quantized = True
53
+ quantization_config = Mxfp4Config(dequantize=False)
54
+ else:
55
+ quantization_config = Mxfp4Config(dequantize=True)
56
+
57
+ try:
58
+ self.model = GptOssForCausalLM.from_pretrained(
59
+ checkpoint,
60
+ quantization_config=quantization_config,
61
+ dtype="auto",
62
+ device_map="auto",
63
+ attn_implementation="eager",
64
+ )
65
+ except Exception:
66
+ self._log("Error trying to load the model. Defaulting to load without quantization...", "WARNING")
67
+ self.model = GptOssForCausalLM.from_pretrained(
68
+ checkpoint,
69
+ dtype="auto",
70
+ device_map="auto",
71
+ attn_implementation="eager"
72
+ )
73
+
74
+ def _build_input(
75
+ self,
76
+ input_text: str,
77
+ expected_answer: str | None = None,
78
+ system_message: str | None = None,
79
+ reasoning_level: Literal["Low", "Medium", "High"] | None = None,
80
+ reasoning_message: str | None = None,
81
+ developer_message: str | None = None
82
+ ) -> str:
83
+ if not self.tokenizer:
84
+ raise MissingEssentialProp("Could not find tokenizer.")
85
+
86
+ reasoning = reasoning_level
87
+ if reasoning is None:
88
+ reasoning = self.reasoning_level
89
+
90
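+ # Hand-build the harmony chat format (system/developer/user/assistant turns with
+ # channel markers) instead of relying on the tokenizer's chat template.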
+ system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\nReasoning: {reasoning}\n\n{system_message or ""}# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
91
+
92
+ developer_text = ""
93
+ if developer_message:
94
+ developer_text = f"<|start|>developer<|message|># Instructions\n\n{developer_message or ""}<|end|>"
95
+
96
+ assistant_text = ""
97
+ if reasoning_message:
98
+ assistant_text += f"<|start|>assistant<|channel|>analysis<|message|>{reasoning_message}<|end|>"
99
+
100
+ if expected_answer:
101
+ assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"
102
+
103
+ return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{input_text}<|end|>{assistant_text}""")
104
+
105
+ def build_input(
106
+ self,
107
+ input_text: str,
108
+ system_message: str | None = None,
109
+ developer_message: str | None = None,
110
+ expected_answer: str | None = None,
111
+ reasoning_message: str | None = None,
112
+ reasoning_level: Literal["Low", "Medium", "High"] | None = None
113
+ ) -> GPTOSSInput:
114
+ if not self.tokenizer:
115
+ raise MissingEssentialProp("Could not find tokenizer.")
116
+
117
+ return {
118
+ "input_text": input_text,
119
+ "developer_message": developer_message,
120
+ "system_message": system_message,
121
+ "reasoning_level": reasoning_level,
122
+ "expected_answer": expected_answer,
123
+ "reasoning_message": reasoning_message
124
+ }
125
+
126
+ def set_reasoning_level(
127
+ self,
128
+ level: Literal["Low", "Medium", "High"]
129
+ ) -> None:
130
+ self.reasoning_level = level
131
+
132
+ def generate(
133
+ self,
134
+ input: GPTOSSInput | str,
135
+ params: GenerationParams | None = None
136
+ ) -> str | None:
137
+ if self.model is None or self.tokenizer is None:
138
+ self._log("Model or Tokenizer missing", "WARNING")
139
+ return None
140
+
141
+ self._log(f"Processing received input...'")
142
+
143
+ if params is None:
144
+ params = GenerationParams(max_new_tokens=32768)
145
+ elif params.max_new_tokens is None:
146
+ params.max_new_tokens = 32768
147
+
148
+ generation_params = create_generation_params(params)
149
+ self.model.generation_config = generation_params
150
+
151
+ model_input = None
152
+ if isinstance(input, str):
153
+ model_input = self._build_input(
154
+ input_text=input
155
+ )
156
+ else:
157
+ model_input = self._build_input(
158
+ input_text=input["input_text"],
159
+ developer_message=input.get("developer_message", None),
160
+ system_message=input.get("system_message", None),
161
+ reasoning_level=input.get("reasoning_level", None)
162
+ )
163
+
164
+ tokenized_input = self._tokenize(model_input)
165
+
166
+ input_ids, attention_mask = tokenized_input
167
+
168
+ self.model.eval()
169
+ self.model.gradient_checkpointing_disable()
170
+ start = time()
171
+
172
+ with torch.no_grad():
173
+ outputs = self.model.generate(
174
+ input_ids=input_ids,
175
+ attention_mask=attention_mask,
176
+ use_cache=True,
177
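+ # eos_token_id is disabled so stopping is governed solely by the StopOnToken criteria below.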
+ eos_token_id=None,
178
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
179
+ )
180
+
181
+ answer = self.tokenizer.decode(outputs[0])
182
+
183
+ end = time()
184
+ total_time = end - start
185
+
186
+ self._log(f"Response generated in {total_time:.4f} seconds")
187
+
188
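+ # Extract the last <|message|> ... <|return|> span, i.e. the assistant's final-channel
+ # reply, discarding the preceding analysis output.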
+ start = answer.rfind("<|message|>")
189
+ if start == -1:
190
+ return ""
191
+
192
+ start += len("<|message|>")
193
+
194
+ end = answer.find("<|return|>", start)
195
+ if end == -1:
196
+ end = len(answer)
197
+
198
+ return answer[start:end].strip()
199
+
200
+ def generate_stream(
201
+ self,
202
+ input: GPTOSSInput | str,
203
+ params: GenerationParams | None = None
204
+ ) -> Iterator[str]:
205
+ if self.model is None or self.tokenizer is None:
206
+ self._log("Model or Tokenizer missing", "WARNING")
207
+ if False:
208
+ yield ""
209
+ return
210
+
211
+ if params is None:
212
+ params = GenerationParams(max_new_tokens=32768)
213
+ elif params.max_new_tokens is None:
214
+ params.max_new_tokens = 32768
215
+
216
+ generation_params = create_generation_params(params)
217
+ self.model.generation_config = generation_params
218
+
219
+ if isinstance(input, str):
220
+ model_input = self._build_input(
221
+ input_text=input
222
+ )
223
+ else:
224
+ model_input = self._build_input(
225
+ input_text=input["input_text"],
226
+ developer_message=input.get("developer_message"),
227
+ system_message=input.get("system_message"),
228
+ reasoning_level=input.get("reasoning_level")
229
+ )
230
+
231
+ tokenized_input = self._tokenize(model_input)
232
+ input_ids, attention_mask = tokenized_input
233
+
234
+ streamer = TextIteratorStreamer(
235
+ cast(AutoTokenizer, self.tokenizer),
236
+ skip_prompt=True,
237
+ skip_special_tokens=True
238
+ )
239
+
240
+ generate_fn = partial(
241
+ self.model.generate,
242
+ input_ids=input_ids,
243
+ attention_mask=attention_mask,
244
+ use_cache=True,
245
+ eos_token_id=None,
246
+ streamer=streamer,
247
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
248
+ )
249
+
250
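+ # Run generation in a background thread; the streamer yields decoded text as it arrives.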
+ thread = threading.Thread(target=generate_fn)
251
+ thread.start()
252
+
253
+ done_thinking = False
254
+ buffer = ""
255
+
256
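+ # Buffer decoded text until the "final" channel marker appears, then stream the answer.
+ # Note: special tokens are skipped, so this matches the plain word "final" and could
+ # trigger early if the analysis text itself contains it.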
+ for new_text in streamer:
257
+ buffer += new_text
258
+
259
+ if "final" in buffer:
260
+ done_thinking = True
261
+ buffer = buffer.split("final", 1)[1]
262
+
263
+ if done_thinking:
264
+ yield buffer
265
+ buffer = ""
@@ -0,0 +1,247 @@
1
+ import textwrap
2
+ import threading
3
+ from functools import partial
4
+ from time import time
5
+ from typing import Iterator, Literal, TypedDict, cast
6
+
7
+ import torch
8
+ from transformers import (AutoTokenizer, StoppingCriteriaList,
9
+ TextIteratorStreamer)
10
+ from transformers.models.gemma3 import Gemma3ForCausalLM
11
+ from transformers.utils.quantization_config import BitsAndBytesConfig
12
+
13
+ from llmflowstack.base.base import BaseModel
14
+ from llmflowstack.callbacks.stop_on_token import StopOnToken
15
+ from llmflowstack.schemas.params import GenerationParams
16
+ from llmflowstack.utils.exceptions import MissingEssentialProp
17
+ from llmflowstack.utils.generation_utils import create_generation_params
18
+
19
+
20
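+ # Prompt fields accepted by Gemma: the user text plus an optional system message and
+ # (apparently as a fine-tuning target) an expected answer.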
+ class GemmaInput(TypedDict):
21
+ input_text: str
22
+ expected_answer: str | None
23
+ system_message: str | None
24
+
25
+ class Gemma(BaseModel):
26
+ model: Gemma3ForCausalLM | None = None
27
+ can_think = False
28
+ question_fields = ["input_text", "system_message"]
29
+ answer_fields = ["expected_answer"]
30
+
31
+ def _set_generation_stopping_tokens(
32
+ self,
33
+ tokens: list[int]
34
+ ) -> None:
35
+ if not self.tokenizer:
36
+ self._log("Could not set stop tokens - generation may not work...", "WARNING")
37
+ return None
38
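+ # add_special_tokens=False keeps the BOS token out of the stop-token id list.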
+ particular_tokens = self.tokenizer.encode("<end_of_turn>", add_special_tokens=False)
39
+ self.stop_token_ids = tokens + particular_tokens
40
+
41
+ def _load_model(
42
+ self,
43
+ checkpoint: str,
44
+ quantization: Literal["8bit", "4bit"] | bool | None = None
45
+ ) -> None:
46
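+ # Map the quantization flag onto a bitsandbytes config; None loads full-precision weights.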
+ quantization_config = None
47
+ if quantization == "4bit":
48
+ quantization_config = BitsAndBytesConfig(
49
+ load_in_4bit=True
50
+ )
51
+ if quantization == "8bit":
52
+ quantization_config = BitsAndBytesConfig(
53
+ load_in_8bit=True
54
+ )
55
+
56
+ self.model = Gemma3ForCausalLM.from_pretrained(
57
+ checkpoint,
58
+ quantization_config=quantization_config,
59
+ dtype="auto",
60
+ device_map="auto",
61
+ attn_implementation="eager"
62
+ )
63
+
64
+ def _build_input(
65
+ self,
66
+ input_text: str,
67
+ expected_answer: str | None = None,
68
+ system_message: str | None = None
69
+ ) -> str:
70
+ if not self.tokenizer:
71
+ raise MissingEssentialProp("Could not find tokenizer.")
72
+
73
+ if not system_message:
74
+ system_message = ""
75
+ if self.can_think:
76
+ system_message += f"think silently if needed. {system_message}"
77
+
78
+ if system_message:
79
+ system_message = f"{system_message}\n"
80
+
81
+ answer = f"{expected_answer}<end_of_turn>" if expected_answer else ""
82
+
83
+ return textwrap.dedent(
84
+ f"<start_of_turn>user"
85
+ f"{system_message}\n{input_text}<end_of_turn>\n"
86
+ f"<start_of_turn>model\n"
87
+ f"{answer}"
88
+ )
89
+
90
+ def build_input(
91
+ self,
92
+ input_text: str,
93
+ expected_answer: str | None = None,
94
+ system_message: str | None = None
95
+ ) -> GemmaInput:
96
+ if not self.tokenizer:
97
+ raise MissingEssentialProp("Could not find tokenizer.")
98
+
99
+ return {
100
+ "input_text": input_text,
101
+ "expected_answer": expected_answer,
102
+ "system_message": system_message
103
+ }
104
+
105
+ def set_can_think(self, value: bool) -> None:
106
+ self.can_think = value
107
+
108
+ def generate(
109
+ self,
110
+ input: GemmaInput | str,
111
+ params: GenerationParams | None = None,
112
+ ) -> str | None:
113
+ if self.model is None or self.tokenizer is None:
114
+ self._log("Model or Tokenizer missing", "WARNING")
115
+ return None
116
+
117
+ self._log(f"Processing received input...'")
118
+
119
+ if params is None:
120
+ params = GenerationParams(max_new_tokens=32768)
121
+ elif params.max_new_tokens is None:
122
+ params.max_new_tokens = 32768
123
+
124
+ generation_params = create_generation_params(params)
125
+ self.model.generation_config = generation_params
126
+
127
+ model_input = None
128
+ if isinstance(input, str):
129
+ model_input = self._build_input(
130
+ input_text=input
131
+ )
132
+ else:
133
+ model_input = self._build_input(
134
+ input_text=input["input_text"],
135
+ system_message=input["system_message"]
136
+ )
137
+
138
+ tokenized_input = self._tokenize(model_input)
139
+
140
+ input_ids, attention_mask = tokenized_input
141
+
142
+ self.model.eval()
143
+ self.model.gradient_checkpointing_disable()
144
+ start = time()
145
+
146
+ with torch.no_grad():
147
+ outputs = self.model.generate(
148
+ input_ids=input_ids,
149
+ attention_mask=attention_mask,
150
+ use_cache=True,
151
+ eos_token_id=None,
152
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
153
+ )
154
+
155
+ answer = self.tokenizer.decode(outputs[0])
156
+
157
+ end = time()
158
+ total_time = end - start
159
+
160
+ self._log(f"Response generated in {total_time:.4f} seconds")
161
+
162
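+ # This library treats <unused95> as Gemma's end-of-thinking marker; if it is absent,
+ # fall back to everything after the final <start_of_turn>model header.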
+ start = answer.rfind("<unused95>")
163
+ if start == -1:
164
+ start = answer.rfind("<start_of_turn>model")
165
+ start = start + len("<start_of_turn>model")
166
+ else:
167
+ start = start + len("<unused95>")
168
+
169
+ end = answer.find("<end_of_turn>", start)
170
+ if end == -1:
171
+ end = len(answer)
172
+
173
+ return answer[start:end].strip().replace("<eos>", "")
174
+
175
+ def generate_stream(
176
+ self,
177
+ input: GemmaInput | str,
178
+ params: GenerationParams | None = None
179
+ ) -> Iterator[str]:
180
+ if self.model is None or self.tokenizer is None:
181
+ self._log("Model or Tokenizer missing", "WARNING")
182
+ if False:
183
+ yield ""
184
+ return
185
+
186
+ if params is None:
187
+ params = GenerationParams(max_new_tokens=32768)
188
+ elif params.max_new_tokens is None:
189
+ params.max_new_tokens = 32768
190
+
191
+ generation_params = create_generation_params(params)
192
+ self.model.generation_config = generation_params
193
+
194
+ if isinstance(input, str):
195
+ model_input = self._build_input(
196
+ input_text=input
197
+ )
198
+ else:
199
+ model_input = self._build_input(
200
+ input_text=input["input_text"],
201
+ system_message=input.get("system_message")
202
+ )
203
+
204
+ tokenized_input = self._tokenize(model_input)
205
+ input_ids, attention_mask = tokenized_input
206
+
207
+ streamer = TextIteratorStreamer(
208
+ cast(AutoTokenizer, self.tokenizer),
209
+ skip_prompt=True,
210
+ skip_special_tokens=True
211
+ )
212
+
213
+ generate_fn = partial(
214
+ self.model.generate,
215
+ input_ids=input_ids,
216
+ attention_mask=attention_mask,
217
+ use_cache=True,
218
+ eos_token_id=None,
219
+ streamer=streamer,
220
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
221
+ )
222
+
223
+ thread = threading.Thread(target=generate_fn)
224
+ thread.start()
225
+
226
+ buffer = ""
227
+ is_thinking = None
228
+
229
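+ # Heuristic: peek at the first few streamed words to decide whether the model is
+ # "thinking"; suppress output until <unused95> closes the thought, otherwise stream as-is.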
+ for new_text in streamer:
230
+ buffer += new_text
231
+
232
+ if is_thinking is None:
233
+ if len(buffer.split()) > 5:
234
+ is_thinking = False
235
+ continue
236
+
237
+ lower_buffer = buffer.lower()
238
+ if lower_buffer.find("thought") != -1 or lower_buffer.find("<unused94>") != -1:
239
+ is_thinking = True
240
+ continue
241
+ elif not is_thinking:
242
+ yield buffer
243
+ buffer = ""
244
+ else:
245
+ if buffer.find("<unused95>") != -1:
246
+ is_thinking = False
247
+ buffer = buffer.split("<unused95>", 1)[1]