llmflowstack 1.0.2__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/PKG-INFO +19 -5
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/README.md +16 -3
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/__init__.py +7 -3
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/base/base.py +15 -8
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/models/GPT_OSS.py +42 -23
- llmflowstack-1.1.1/llmflowstack/models/Gemma.py +319 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/models/LLaMA3.py +39 -14
- llmflowstack-1.1.1/llmflowstack/models/LLaMA4.py +317 -0
- llmflowstack-1.0.2/llmflowstack/models/Gemma.py → llmflowstack-1.1.1/llmflowstack/models/MedGemma.py +43 -22
- llmflowstack-1.1.1/llmflowstack/models/__init__.py +13 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/rag/pipeline.py +8 -1
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/pyproject.toml +3 -2
- llmflowstack-1.0.2/llmflowstack/models/__init__.py +0 -9
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/.github/workflows/python-publish.yml +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/.gitignore +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/LICENSE +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/base/__init__.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/callbacks/__init__.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/callbacks/log_collector.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/callbacks/stop_on_token.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/rag/__iinit__.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/schemas/__init__.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/schemas/params.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/__init__.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/evaluation_methods.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/exceptions.py +0 -0
- {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/generation_utils.py +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: llmflowstack
|
|
3
|
-
Version: 1.0.2
|
|
4
|
-
Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
|
|
3
|
+
Version: 1.1.1
|
|
4
|
+
Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
|
|
5
5
|
Author-email: Gustavo Henrique Ferreira Cruz <gustavohferreiracruz@gmail.com>
|
|
6
6
|
License: MIT
|
|
7
7
|
License-File: LICENSE
|
|
@@ -14,6 +14,7 @@ Requires-Dist: colorama
|
|
|
14
14
|
Requires-Dist: datasets
|
|
15
15
|
Requires-Dist: evaluate
|
|
16
16
|
Requires-Dist: huggingface-hub
|
|
17
|
+
Requires-Dist: kernels
|
|
17
18
|
Requires-Dist: langchain-chroma
|
|
18
19
|
Requires-Dist: langchain-community
|
|
19
20
|
Requires-Dist: nltk
|
|
@@ -56,18 +57,31 @@ This framework is designed to provide flexibility when working with different op
|
|
|
56
57
|
|
|
57
58
|
- [`GPT-OSS 20B`](https://huggingface.co/openai/gpt-oss-20b)
|
|
58
59
|
- [`GPT-OSS 120B`](https://huggingface.co/openai/gpt-oss-120b)
|
|
60
|
+
> Fine-Tuning, DAPT and Inference Available
|
|
59
61
|
|
|
60
|
-
- **LLaMA**
|
|
62
|
+
- **LLaMA 3**
|
|
61
63
|
|
|
62
64
|
- [`LLaMA 3.1 8B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
|
|
63
65
|
- [`LLaMA 3.1 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct)
|
|
64
66
|
- [`LLaMA 3.3 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
|
|
65
67
|
- [`LLaMA 3.3 405B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct)
|
|
68
|
+
> Fine-Tuning, DAPT and Inference Available
|
|
69
|
+
|
|
70
|
+
- **LLaMA 4**
|
|
71
|
+
|
|
72
|
+
- [`LLaMA 4 Scout - Instruct`](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
|
|
73
|
+
> DAPT and Inference Available
|
|
66
74
|
|
|
67
75
|
- **Gemma**
|
|
68
|
-
- [`MedGemma 27B Text - It`](https://huggingface.co/google/medgemma-27b-text-it)
|
|
69
76
|
|
|
70
|
-
|
|
77
|
+
- [`Gemma 3 27B - Instruct`](https://huggingface.co/google/gemma-3-27b-it)
|
|
78
|
+
> DAPT and Inference Available
|
|
79
|
+
|
|
80
|
+
- **MedGemma**
|
|
81
|
+
- [`MedGemma 27B Text - Instruct`](https://huggingface.co/google/medgemma-27b-text-it)
|
|
82
|
+
> Fine-Tuning, DAPT and Inference Available
|
|
83
|
+
|
|
84
|
+
> Other architectures based on those **may** function correctly.
|
|
71
85
|
|
|
72
86
|
---
|
|
73
87
|
|
|
@@ -20,18 +20,31 @@ This framework is designed to provide flexibility when working with different op
|
|
|
20
20
|
|
|
21
21
|
- [`GPT-OSS 20B`](https://huggingface.co/openai/gpt-oss-20b)
|
|
22
22
|
- [`GPT-OSS 120B`](https://huggingface.co/openai/gpt-oss-120b)
|
|
23
|
+
> Fine-Tuning, DAPT and Inference Available
|
|
23
24
|
|
|
24
|
-
- **LLaMA**
|
|
25
|
+
- **LLaMA 3**
|
|
25
26
|
|
|
26
27
|
- [`LLaMA 3.1 8B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
|
|
27
28
|
- [`LLaMA 3.1 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct)
|
|
28
29
|
- [`LLaMA 3.3 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
|
|
29
30
|
- [`LLaMA 3.3 405B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct)
|
|
31
|
+
> Fine-Tuning, DAPT and Inference Available
|
|
32
|
+
|
|
33
|
+
- **LLaMA 4**
|
|
34
|
+
|
|
35
|
+
- [`LLaMA 4 Scout - Instruct`](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
|
|
36
|
+
> DAPT and Inference Available
|
|
30
37
|
|
|
31
38
|
- **Gemma**
|
|
32
|
-
- [`MedGemma 27B Text - It`](https://huggingface.co/google/medgemma-27b-text-it)
|
|
33
39
|
|
|
34
|
-
|
|
40
|
+
- [`Gemma 3 27B - Instruct`](https://huggingface.co/google/gemma-3-27b-it)
|
|
41
|
+
> DAPT and Inference Available
|
|
42
|
+
|
|
43
|
+
- **MedGemma**
|
|
44
|
+
- [`MedGemma 27B Text - Instruct`](https://huggingface.co/google/medgemma-27b-text-it)
|
|
45
|
+
> Fine-Tuning, DAPT and Inference Available
|
|
46
|
+
|
|
47
|
+
> Other architectures based on those **may** function correctly.
|
|
35
48
|
|
|
36
49
|
---
|
|
37
50
|
|
|
@@ -1,15 +1,19 @@
|
|
|
1
|
-
from .models.Gemma import
|
|
1
|
+
from .models.Gemma import Gemma3
|
|
2
2
|
from .models.GPT_OSS import GPT_OSS
|
|
3
3
|
from .models.LLaMA3 import LLaMA3
|
|
4
|
+
from .models.LLaMA4 import LLaMA4
|
|
5
|
+
from .models.MedGemma import MedGemma
|
|
4
6
|
from .rag.pipeline import RAGPipeline
|
|
5
7
|
from .schemas.params import (GenerationBeamsParams, GenerationParams,
|
|
6
8
|
GenerationSampleParams, TrainParams)
|
|
7
9
|
from .utils.evaluation_methods import text_evaluation
|
|
8
10
|
|
|
9
11
|
__all__ = [
|
|
10
|
-
"
|
|
11
|
-
"LLaMA3",
|
|
12
|
+
"Gemma3",
|
|
12
13
|
"GPT_OSS",
|
|
14
|
+
"LLaMA3",
|
|
15
|
+
"LLaMA4",
|
|
16
|
+
"MedGemma",
|
|
13
17
|
"RAGPipeline",
|
|
14
18
|
"GenerationBeamsParams",
|
|
15
19
|
"GenerationParams",
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import gc
|
|
1
2
|
import json
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
@@ -35,7 +36,7 @@ class BaseModel(ABC):
|
|
|
35
36
|
def __init__(
|
|
36
37
|
self,
|
|
37
38
|
checkpoint: str | None = None,
|
|
38
|
-
quantization: Literal["
|
|
39
|
+
quantization: Literal["4bit", "8bit"] | bool | None = None,
|
|
39
40
|
seed: int | None = None,
|
|
40
41
|
log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
|
|
41
42
|
) -> None:
|
|
@@ -64,7 +65,8 @@ class BaseModel(ABC):
|
|
|
64
65
|
def _load_model(
|
|
65
66
|
self,
|
|
66
67
|
checkpoint: str,
|
|
67
|
-
|
|
68
|
+
*args: Any,
|
|
69
|
+
**kwargs: Any
|
|
68
70
|
) -> None:
|
|
69
71
|
pass
|
|
70
72
|
|
|
@@ -79,7 +81,7 @@ class BaseModel(ABC):
|
|
|
79
81
|
def load_checkpoint(
|
|
80
82
|
self,
|
|
81
83
|
checkpoint: str,
|
|
82
|
-
quantization:
|
|
84
|
+
quantization: Any
|
|
83
85
|
) -> None:
|
|
84
86
|
if self.model:
|
|
85
87
|
self._log("A model is already loaded. Attempting to reset it.", "WARNING")
|
|
@@ -223,7 +225,7 @@ class BaseModel(ABC):
|
|
|
223
225
|
self,
|
|
224
226
|
*args: Any,
|
|
225
227
|
**kwargs: Any
|
|
226
|
-
) -> str:
|
|
228
|
+
) -> str | BatchEncoding:
|
|
227
229
|
pass
|
|
228
230
|
|
|
229
231
|
def _tokenize(
|
|
@@ -282,7 +284,7 @@ class BaseModel(ABC):
|
|
|
282
284
|
output = []
|
|
283
285
|
for data in dataset:
|
|
284
286
|
complete_input = self._build_input(
|
|
285
|
-
|
|
287
|
+
data
|
|
286
288
|
)
|
|
287
289
|
output.append(complete_input)
|
|
288
290
|
|
|
@@ -403,13 +405,16 @@ class BaseModel(ABC):
|
|
|
403
405
|
def _build_input_for_fine_tune(
|
|
404
406
|
self,
|
|
405
407
|
input: dict
|
|
406
|
-
) -> dict[Literal["partial", "complete"], str]:
|
|
408
|
+
) -> dict[Literal["partial", "complete"], str | BatchEncoding]:
|
|
407
409
|
if not self.tokenizer:
|
|
408
410
|
raise MissingEssentialProp("Could not find tokenizer.")
|
|
409
411
|
|
|
410
|
-
partial = self._build_input(
|
|
412
|
+
partial = self._build_input({
|
|
413
|
+
**input,
|
|
414
|
+
"expected_answer": None
|
|
415
|
+
})
|
|
411
416
|
|
|
412
|
-
complete = self._build_input(
|
|
417
|
+
complete = self._build_input(input)
|
|
413
418
|
|
|
414
419
|
return {
|
|
415
420
|
"partial": partial,
|
|
@@ -508,6 +513,8 @@ class BaseModel(ABC):
|
|
|
508
513
|
try:
|
|
509
514
|
self._log("Trying to reset model...")
|
|
510
515
|
del self.model
|
|
516
|
+
gc.collect()
|
|
517
|
+
torch.cuda.empty_cache()
|
|
511
518
|
self.model = None
|
|
512
519
|
self.model_is_quantized = None
|
|
513
520
|
self.process_id = None
|
|
@@ -2,7 +2,7 @@ import textwrap
|
|
|
2
2
|
import threading
|
|
3
3
|
from functools import partial
|
|
4
4
|
from time import time
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Iterator, Literal, TypedDict, cast
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
8
|
from openai_harmony import HarmonyEncodingName, load_harmony_encoding
|
|
@@ -32,6 +32,20 @@ class GPT_OSS(BaseModel):
|
|
|
32
32
|
question_fields = ["input_text", "developer_message", "system_message"]
|
|
33
33
|
answer_fields = ["expected_answer", "reasoning_message"]
|
|
34
34
|
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
checkpoint: str | None = None,
|
|
38
|
+
quantization: bool | None = None,
|
|
39
|
+
seed: int | None = None,
|
|
40
|
+
log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
|
|
41
|
+
) -> None:
|
|
42
|
+
return super().__init__(
|
|
43
|
+
checkpoint=checkpoint,
|
|
44
|
+
quantization=quantization,
|
|
45
|
+
seed=seed,
|
|
46
|
+
log_level=log_level
|
|
47
|
+
)
|
|
48
|
+
|
|
35
49
|
def _set_generation_stopping_tokens(
|
|
36
50
|
self,
|
|
37
51
|
tokens: list[int]
|
|
@@ -46,10 +60,9 @@ class GPT_OSS(BaseModel):
|
|
|
46
60
|
def _load_model(
|
|
47
61
|
self,
|
|
48
62
|
checkpoint: str,
|
|
49
|
-
quantization:
|
|
63
|
+
quantization: bool | None = False
|
|
50
64
|
) -> None:
|
|
51
65
|
if quantization:
|
|
52
|
-
self.model_is_quantized = True
|
|
53
66
|
quantization_config = Mxfp4Config(dequantize=False)
|
|
54
67
|
else:
|
|
55
68
|
quantization_config = Mxfp4Config(dequantize=True)
|
|
@@ -70,37 +83,43 @@ class GPT_OSS(BaseModel):
|
|
|
70
83
|
device_map="auto",
|
|
71
84
|
attn_implementation="eager"
|
|
72
85
|
)
|
|
86
|
+
|
|
87
|
+
def load_checkpoint(
|
|
88
|
+
self,
|
|
89
|
+
checkpoint: str,
|
|
90
|
+
quantization: bool | None = None
|
|
91
|
+
) -> None:
|
|
92
|
+
return super().load_checkpoint(checkpoint, quantization)
|
|
73
93
|
|
|
74
94
|
def _build_input(
|
|
75
95
|
self,
|
|
76
|
-
|
|
77
|
-
expected_answer: str | None = None,
|
|
78
|
-
system_message: str | None = None,
|
|
79
|
-
reasoning_level: Literal["Low", "Medium", "High"] | None = None,
|
|
80
|
-
reasoning_message: str | None = None,
|
|
81
|
-
developer_message: str | None = None
|
|
96
|
+
data: GPTOSSInput
|
|
82
97
|
) -> str:
|
|
83
98
|
if not self.tokenizer:
|
|
84
99
|
raise MissingEssentialProp("Could not find tokenizer.")
|
|
85
100
|
|
|
86
|
-
reasoning = reasoning_level
|
|
101
|
+
reasoning = data.get("reasoning_level")
|
|
87
102
|
if reasoning is None:
|
|
88
103
|
reasoning = self.reasoning_level
|
|
89
104
|
|
|
90
|
-
|
|
105
|
+
system_message = data.get("system_message", "")
|
|
106
|
+
system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\nReasoning: {reasoning}\n\n{system_message}# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
|
|
91
107
|
|
|
92
108
|
developer_text = ""
|
|
109
|
+
developer_message = data.get("developer_message", "")
|
|
93
110
|
if developer_message:
|
|
94
|
-
developer_text = f"<|start|>developer<|message|># Instructions\n\n{developer_message
|
|
111
|
+
developer_text = f"<|start|>developer<|message|># Instructions\n\n{developer_message}<|end|>"
|
|
95
112
|
|
|
96
113
|
assistant_text = ""
|
|
114
|
+
reasoning_message = data.get("reasoning_message", "")
|
|
97
115
|
if reasoning_message:
|
|
98
116
|
assistant_text += f"<|start|>assistant<|channel|>analysis<|message|>{reasoning_message}<|end|>"
|
|
99
117
|
|
|
118
|
+
expected_answer = data.get("expected_answer", "")
|
|
100
119
|
if expected_answer:
|
|
101
120
|
assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"
|
|
102
121
|
|
|
103
|
-
return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{input_text}<|end|>{assistant_text}""")
|
|
122
|
+
return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{data["input_text"]}<|end|>{assistant_text}""")
|
|
104
123
|
|
|
105
124
|
def build_input(
|
|
106
125
|
self,
|
|
@@ -150,15 +169,15 @@ class GPT_OSS(BaseModel):
|
|
|
150
169
|
|
|
151
170
|
model_input = None
|
|
152
171
|
if isinstance(input, str):
|
|
153
|
-
model_input = self.
|
|
172
|
+
model_input = self.build_input(
|
|
154
173
|
input_text=input
|
|
155
174
|
)
|
|
175
|
+
model_input = self._build_input(
|
|
176
|
+
data=model_input
|
|
177
|
+
)
|
|
156
178
|
else:
|
|
157
179
|
model_input = self._build_input(
|
|
158
|
-
|
|
159
|
-
developer_message=input.get("developer_message", None),
|
|
160
|
-
system_message=input.get("system_message", None),
|
|
161
|
-
reasoning_level=input.get("reasoning_level", None)
|
|
180
|
+
data=input
|
|
162
181
|
)
|
|
163
182
|
|
|
164
183
|
tokenized_input = self._tokenize(model_input)
|
|
@@ -217,15 +236,15 @@ class GPT_OSS(BaseModel):
|
|
|
217
236
|
self.model.generation_config = generation_params
|
|
218
237
|
|
|
219
238
|
if isinstance(input, str):
|
|
220
|
-
model_input = self.
|
|
239
|
+
model_input = self.build_input(
|
|
221
240
|
input_text=input
|
|
222
241
|
)
|
|
242
|
+
model_input = self._build_input(
|
|
243
|
+
data=model_input
|
|
244
|
+
)
|
|
223
245
|
else:
|
|
224
246
|
model_input = self._build_input(
|
|
225
|
-
|
|
226
|
-
developer_message=input.get("developer_message"),
|
|
227
|
-
system_message=input.get("system_message"),
|
|
228
|
-
reasoning_level=input.get("reasoning_level")
|
|
247
|
+
data=input
|
|
229
248
|
)
|
|
230
249
|
|
|
231
250
|
tokenized_input = self._tokenize(model_input)
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from functools import partial
|
|
3
|
+
from time import time
|
|
4
|
+
from typing import Iterator, Literal, TypedDict, cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
|
|
8
|
+
StoppingCriteriaList, TextIteratorStreamer, Trainer,
|
|
9
|
+
TrainingArguments)
|
|
10
|
+
from transformers.models.gemma3 import Gemma3ForCausalLM
|
|
11
|
+
from transformers.utils.quantization_config import BitsAndBytesConfig
|
|
12
|
+
|
|
13
|
+
from llmflowstack.base.base import BaseModel
|
|
14
|
+
from llmflowstack.callbacks.log_collector import LogCollectorCallback
|
|
15
|
+
from llmflowstack.callbacks.stop_on_token import StopOnToken
|
|
16
|
+
from llmflowstack.schemas.params import GenerationParams, TrainParams
|
|
17
|
+
from llmflowstack.utils.exceptions import MissingEssentialProp
|
|
18
|
+
from llmflowstack.utils.generation_utils import create_generation_params
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Gemma3Input(TypedDict):
|
|
22
|
+
input_text: str
|
|
23
|
+
expected_answer: str | None
|
|
24
|
+
system_message: str | None
|
|
25
|
+
image_paths: list[str] | None
|
|
26
|
+
|
|
27
|
+
class Gemma3(BaseModel):
|
|
28
|
+
model: Gemma3ForCausalLM | None = None
|
|
29
|
+
question_fields = ["input_text", "system_message"]
|
|
30
|
+
answer_fields = ["expected_answer"]
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
checkpoint: str | None = None,
|
|
35
|
+
quantization: Literal["4bit"] | None = None,
|
|
36
|
+
seed: int | None = None,
|
|
37
|
+
log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
|
|
38
|
+
) -> None:
|
|
39
|
+
return super().__init__(
|
|
40
|
+
checkpoint=checkpoint,
|
|
41
|
+
quantization=quantization,
|
|
42
|
+
seed=seed,
|
|
43
|
+
log_level=log_level
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
def _set_generation_stopping_tokens(
|
|
47
|
+
self,
|
|
48
|
+
tokens: list[int]
|
|
49
|
+
) -> None:
|
|
50
|
+
if not self.tokenizer:
|
|
51
|
+
self._log("Could not set stop tokens - generation may not work...", "WARNING")
|
|
52
|
+
return None
|
|
53
|
+
particular_tokens = self.tokenizer.encode("<end_of_turn>")
|
|
54
|
+
self.stop_token_ids = tokens + particular_tokens
|
|
55
|
+
|
|
56
|
+
def _load_model(
|
|
57
|
+
self,
|
|
58
|
+
checkpoint: str,
|
|
59
|
+
quantization: Literal["4bit"] | None = None
|
|
60
|
+
) -> None:
|
|
61
|
+
quantization_config = None
|
|
62
|
+
if quantization == "4bit":
|
|
63
|
+
quantization_config = BitsAndBytesConfig(
|
|
64
|
+
load_in_4bit=True
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
self.model = Gemma3ForCausalLM.from_pretrained(
|
|
68
|
+
checkpoint,
|
|
69
|
+
quantization_config=quantization_config,
|
|
70
|
+
dtype="auto",
|
|
71
|
+
device_map="auto",
|
|
72
|
+
attn_implementation="eager"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
def load_checkpoint(
|
|
76
|
+
self,
|
|
77
|
+
checkpoint: str,
|
|
78
|
+
quantization: Literal['4bit'] | None = None
|
|
79
|
+
) -> None:
|
|
80
|
+
return super().load_checkpoint(checkpoint, quantization)
|
|
81
|
+
|
|
82
|
+
def _build_input(
|
|
83
|
+
self,
|
|
84
|
+
data: Gemma3Input
|
|
85
|
+
) -> str:
|
|
86
|
+
if not self.tokenizer:
|
|
87
|
+
raise MissingEssentialProp("Could not find tokenizer.")
|
|
88
|
+
|
|
89
|
+
system_message = data.get("system_message", "")
|
|
90
|
+
if not system_message:
|
|
91
|
+
system_message = ""
|
|
92
|
+
|
|
93
|
+
if system_message:
|
|
94
|
+
system_message = f"{system_message}\n"
|
|
95
|
+
|
|
96
|
+
expected_answer = data.get("expected_answer")
|
|
97
|
+
answer = f"{expected_answer}<end_of_turn>" if expected_answer else ""
|
|
98
|
+
|
|
99
|
+
return (
|
|
100
|
+
f"<start_of_turn>user"
|
|
101
|
+
f"{system_message}\n{data["input_text"]}<end_of_turn>\n"
|
|
102
|
+
f"<start_of_turn>model\n"
|
|
103
|
+
f"{answer}"
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
def build_input(
|
|
107
|
+
self,
|
|
108
|
+
input_text: str,
|
|
109
|
+
system_message: str | None = None,
|
|
110
|
+
expected_answer: str | None = None,
|
|
111
|
+
image_paths: list[str] | None = None
|
|
112
|
+
) -> Gemma3Input:
|
|
113
|
+
if not self.tokenizer:
|
|
114
|
+
raise MissingEssentialProp("Could not find tokenizer.")
|
|
115
|
+
|
|
116
|
+
return {
|
|
117
|
+
"input_text": input_text,
|
|
118
|
+
"system_message": system_message,
|
|
119
|
+
"expected_answer": expected_answer,
|
|
120
|
+
"image_paths": image_paths
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
def dapt(
|
|
124
|
+
self,
|
|
125
|
+
train_dataset: list,
|
|
126
|
+
params: TrainParams | None = None,
|
|
127
|
+
eval_dataset: list | None = None,
|
|
128
|
+
save_at_end = True,
|
|
129
|
+
save_path: str | None = None
|
|
130
|
+
) -> None:
|
|
131
|
+
if not self.model:
|
|
132
|
+
self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
|
|
133
|
+
return None
|
|
134
|
+
if not self.tokenizer:
|
|
135
|
+
self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
self._log("Starting Training")
|
|
139
|
+
|
|
140
|
+
if self.model_is_quantized:
|
|
141
|
+
self._log("Cannot traub a quantized model.", "WARNING")
|
|
142
|
+
return None
|
|
143
|
+
|
|
144
|
+
if params is None:
|
|
145
|
+
params = TrainParams()
|
|
146
|
+
|
|
147
|
+
training_arguments = TrainingArguments(
|
|
148
|
+
num_train_epochs=params.epochs,
|
|
149
|
+
learning_rate=params.lr,
|
|
150
|
+
gradient_accumulation_steps=params.gradient_accumulation,
|
|
151
|
+
warmup_ratio=params.warmup_ratio,
|
|
152
|
+
lr_scheduler_type="cosine_with_min_lr",
|
|
153
|
+
lr_scheduler_kwargs={"min_lr_rate": 0.1},
|
|
154
|
+
output_dir=None,
|
|
155
|
+
save_strategy="no",
|
|
156
|
+
logging_steps=params.logging_steps
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
if self.seed is not None:
|
|
160
|
+
training_arguments.seed = self.seed
|
|
161
|
+
|
|
162
|
+
processed_train_dataset = self._promptfy_dataset_for_dapt(train_dataset)
|
|
163
|
+
tokenized_train_dataset = self._tokenize_dataset_for_dapt(processed_train_dataset)
|
|
164
|
+
|
|
165
|
+
tokenized_eval_dataset = None
|
|
166
|
+
if eval_dataset:
|
|
167
|
+
processed_eval_dataset = self._promptfy_dataset_for_dapt(eval_dataset)
|
|
168
|
+
tokenized_eval_dataset = self._tokenize_dataset_for_dapt(processed_eval_dataset)
|
|
169
|
+
|
|
170
|
+
log_callback = LogCollectorCallback()
|
|
171
|
+
|
|
172
|
+
trainer = Trainer(
|
|
173
|
+
model=self.model,
|
|
174
|
+
train_dataset=tokenized_train_dataset,
|
|
175
|
+
eval_dataset=tokenized_eval_dataset,
|
|
176
|
+
args=training_arguments,
|
|
177
|
+
callbacks=[log_callback],
|
|
178
|
+
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
trainer.train()
|
|
182
|
+
|
|
183
|
+
if save_at_end and save_path:
|
|
184
|
+
self.save_checkpoint(
|
|
185
|
+
path=save_path
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
self._log("Finished Training")
|
|
189
|
+
|
|
190
|
+
def fine_tune(
|
|
191
|
+
self,
|
|
192
|
+
train_dataset: list,
|
|
193
|
+
params: TrainParams | None = None,
|
|
194
|
+
eval_dataset: list | None = None,
|
|
195
|
+
save_at_end = True,
|
|
196
|
+
save_path: str | None = None
|
|
197
|
+
) -> None:
|
|
198
|
+
self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
|
|
199
|
+
return self.dapt(
|
|
200
|
+
train_dataset=train_dataset,
|
|
201
|
+
params=params,
|
|
202
|
+
eval_dataset=eval_dataset,
|
|
203
|
+
save_at_end=save_at_end,
|
|
204
|
+
save_path=save_path
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
def generate(
|
|
208
|
+
self,
|
|
209
|
+
input: Gemma3Input | str,
|
|
210
|
+
params: GenerationParams | None = None,
|
|
211
|
+
) -> str | None:
|
|
212
|
+
if self.model is None or self.tokenizer is None:
|
|
213
|
+
self._log("Model or Tokenizer missing", "WARNING")
|
|
214
|
+
return None
|
|
215
|
+
|
|
216
|
+
self._log(f"Processing received input...'")
|
|
217
|
+
|
|
218
|
+
if params is None:
|
|
219
|
+
params = GenerationParams(max_new_tokens=32768)
|
|
220
|
+
elif params.max_new_tokens is None:
|
|
221
|
+
params.max_new_tokens = 32768
|
|
222
|
+
|
|
223
|
+
generation_params = create_generation_params(params)
|
|
224
|
+
self.model.generation_config = generation_params
|
|
225
|
+
|
|
226
|
+
model_input = None
|
|
227
|
+
if isinstance(input, str):
|
|
228
|
+
model_input = self.build_input(
|
|
229
|
+
input_text=input
|
|
230
|
+
)
|
|
231
|
+
model_input = self._build_input(
|
|
232
|
+
data=model_input
|
|
233
|
+
)
|
|
234
|
+
else:
|
|
235
|
+
model_input = self._build_input(
|
|
236
|
+
data=input
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
tokenized_input = self._tokenize(model_input)
|
|
240
|
+
input_ids, attention_mask = tokenized_input
|
|
241
|
+
|
|
242
|
+
self.model.eval()
|
|
243
|
+
self.model.gradient_checkpointing_disable()
|
|
244
|
+
start = time()
|
|
245
|
+
|
|
246
|
+
with torch.no_grad():
|
|
247
|
+
outputs = self.model.generate(
|
|
248
|
+
input_ids=input_ids,
|
|
249
|
+
attention_mask=attention_mask,
|
|
250
|
+
use_cache=True,
|
|
251
|
+
eos_token_id=None,
|
|
252
|
+
stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
end = time()
|
|
256
|
+
total_time = end - start
|
|
257
|
+
|
|
258
|
+
self._log(f"Response generated in {total_time:.4f} seconds")
|
|
259
|
+
|
|
260
|
+
response = outputs[0][input_ids.shape[1]:]
|
|
261
|
+
|
|
262
|
+
return self.tokenizer.decode(response, skip_special_tokens=True)
|
|
263
|
+
|
|
264
|
+
def generate_stream(
|
|
265
|
+
self,
|
|
266
|
+
input: Gemma3Input | str,
|
|
267
|
+
params: GenerationParams | None = None
|
|
268
|
+
) -> Iterator[str]:
|
|
269
|
+
if self.model is None or self.tokenizer is None:
|
|
270
|
+
self._log("Model or Tokenizer missing", "WARNING")
|
|
271
|
+
if False:
|
|
272
|
+
yield ""
|
|
273
|
+
return
|
|
274
|
+
|
|
275
|
+
if params is None:
|
|
276
|
+
params = GenerationParams(max_new_tokens=32768)
|
|
277
|
+
elif params.max_new_tokens is None:
|
|
278
|
+
params.max_new_tokens = 32768
|
|
279
|
+
|
|
280
|
+
generation_params = create_generation_params(params)
|
|
281
|
+
self.model.generation_config = generation_params
|
|
282
|
+
|
|
283
|
+
model_input = None
|
|
284
|
+
if isinstance(input, str):
|
|
285
|
+
model_input = self.build_input(
|
|
286
|
+
input_text=input
|
|
287
|
+
)
|
|
288
|
+
model_input = self._build_input(
|
|
289
|
+
data=model_input
|
|
290
|
+
)
|
|
291
|
+
else:
|
|
292
|
+
model_input = self._build_input(
|
|
293
|
+
data=input
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
tokenized_input = self._tokenize(model_input)
|
|
297
|
+
input_ids, attention_mask = tokenized_input
|
|
298
|
+
|
|
299
|
+
streamer = TextIteratorStreamer(
|
|
300
|
+
cast(AutoTokenizer, self.tokenizer),
|
|
301
|
+
skip_prompt=True,
|
|
302
|
+
skip_special_tokens=True
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
generate_fn = partial(
|
|
306
|
+
self.model.generate,
|
|
307
|
+
input_ids=input_ids,
|
|
308
|
+
attention_mask=attention_mask,
|
|
309
|
+
use_cache=True,
|
|
310
|
+
eos_token_id=None,
|
|
311
|
+
streamer=streamer,
|
|
312
|
+
stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
thread = threading.Thread(target=generate_fn)
|
|
316
|
+
thread.start()
|
|
317
|
+
|
|
318
|
+
for new_text in streamer:
|
|
319
|
+
yield new_text
|
|
@@ -26,6 +26,20 @@ class LLaMA3(BaseModel):
|
|
|
26
26
|
question_fields = ["input_text", "system_message"]
|
|
27
27
|
answer_fields = ["expected_answer"]
|
|
28
28
|
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
checkpoint: str | None = None,
|
|
32
|
+
quantization: Literal["4bit", "8bit"] | None = None,
|
|
33
|
+
seed: int | None = None,
|
|
34
|
+
log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
|
|
35
|
+
) -> None:
|
|
36
|
+
return super().__init__(
|
|
37
|
+
checkpoint=checkpoint,
|
|
38
|
+
quantization=quantization,
|
|
39
|
+
seed=seed,
|
|
40
|
+
log_level=log_level
|
|
41
|
+
)
|
|
42
|
+
|
|
29
43
|
def _set_generation_stopping_tokens(
|
|
30
44
|
self,
|
|
31
45
|
tokens: list[int]
|
|
@@ -39,19 +53,17 @@ class LLaMA3(BaseModel):
|
|
|
39
53
|
def _load_model(
|
|
40
54
|
self,
|
|
41
55
|
checkpoint: str,
|
|
42
|
-
quantization: Literal["
|
|
56
|
+
quantization: Literal["4bit", "8bit"] | None = None
|
|
43
57
|
) -> None:
|
|
44
58
|
quantization_config = None
|
|
45
59
|
if quantization == "4bit":
|
|
46
60
|
quantization_config = BitsAndBytesConfig(
|
|
47
61
|
load_in_4bit=True
|
|
48
62
|
)
|
|
49
|
-
self.model_is_quantized = True
|
|
50
63
|
if quantization == "8bit":
|
|
51
64
|
quantization_config = BitsAndBytesConfig(
|
|
52
65
|
load_in_8bit=True
|
|
53
66
|
)
|
|
54
|
-
self.model_is_quantized = True
|
|
55
67
|
|
|
56
68
|
self.model = LlamaForCausalLM.from_pretrained(
|
|
57
69
|
checkpoint,
|
|
@@ -60,21 +72,29 @@ class LLaMA3(BaseModel):
|
|
|
60
72
|
device_map="auto",
|
|
61
73
|
attn_implementation="eager"
|
|
62
74
|
)
|
|
75
|
+
|
|
76
|
+
def load_checkpoint(
|
|
77
|
+
self,
|
|
78
|
+
checkpoint: str,
|
|
79
|
+
quantization: Literal['4bit', "8bit"] | None = None
|
|
80
|
+
) -> None:
|
|
81
|
+
return super().load_checkpoint(checkpoint, quantization)
|
|
63
82
|
|
|
64
83
|
def _build_input(
|
|
65
84
|
self,
|
|
66
|
-
|
|
67
|
-
expected_answer: str | None = None,
|
|
68
|
-
system_message: str | None = None
|
|
85
|
+
data: LLaMA3Input
|
|
69
86
|
) -> str:
|
|
70
87
|
if not self.tokenizer:
|
|
71
88
|
raise MissingEssentialProp("Could not find tokenizer.")
|
|
72
89
|
|
|
90
|
+
expected_answer = data.get("expected_answer")
|
|
73
91
|
answer = f"{expected_answer}{self.tokenizer.eos_token}" if expected_answer else ""
|
|
74
92
|
|
|
93
|
+
system_message = data.get("system_message", "")
|
|
94
|
+
|
|
75
95
|
return textwrap.dedent(
|
|
76
|
-
f"<|start_header_id|>system<|end_header_id|>{system_message
|
|
77
|
-
f"<|eot_id|><|start_header_id|>user<|end_header_id|>{input_text}\n"
|
|
96
|
+
f"<|start_header_id|>system<|end_header_id|>{system_message}\n"
|
|
97
|
+
f"<|eot_id|><|start_header_id|>user<|end_header_id|>{data["input_text"]}\n"
|
|
78
98
|
f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>{answer}"
|
|
79
99
|
)
|
|
80
100
|
|
|
@@ -120,13 +140,15 @@ class LLaMA3(BaseModel):
|
|
|
120
140
|
|
|
121
141
|
model_input = None
|
|
122
142
|
if isinstance(input, str):
|
|
123
|
-
model_input = self.
|
|
143
|
+
model_input = self.build_input(
|
|
124
144
|
input_text=input
|
|
125
145
|
)
|
|
146
|
+
model_input = self._build_input(
|
|
147
|
+
data=model_input
|
|
148
|
+
)
|
|
126
149
|
else:
|
|
127
150
|
model_input = self._build_input(
|
|
128
|
-
|
|
129
|
-
system_message=input.get("system_message", "")
|
|
151
|
+
data=input
|
|
130
152
|
)
|
|
131
153
|
|
|
132
154
|
tokenized_input = self._tokenize(model_input)
|
|
@@ -175,14 +197,17 @@ class LLaMA3(BaseModel):
|
|
|
175
197
|
generation_params = create_generation_params(params)
|
|
176
198
|
self.model.generation_config = generation_params
|
|
177
199
|
|
|
200
|
+
model_input = None
|
|
178
201
|
if isinstance(input, str):
|
|
179
|
-
model_input = self.
|
|
202
|
+
model_input = self.build_input(
|
|
180
203
|
input_text=input
|
|
181
204
|
)
|
|
205
|
+
model_input = self._build_input(
|
|
206
|
+
data=model_input
|
|
207
|
+
)
|
|
182
208
|
else:
|
|
183
209
|
model_input = self._build_input(
|
|
184
|
-
|
|
185
|
-
system_message=input.get("system_message")
|
|
210
|
+
data=input
|
|
186
211
|
)
|
|
187
212
|
|
|
188
213
|
tokenized_input = self._tokenize(model_input)
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from functools import partial
|
|
3
|
+
from time import time
|
|
4
|
+
from typing import Iterator, Literal, TypedDict, cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
|
|
8
|
+
StoppingCriteriaList, TextIteratorStreamer, Trainer,
|
|
9
|
+
TrainingArguments)
|
|
10
|
+
from transformers.models.llama4 import Llama4ForCausalLM
|
|
11
|
+
|
|
12
|
+
from llmflowstack.base.base import BaseModel
|
|
13
|
+
from llmflowstack.callbacks.log_collector import LogCollectorCallback
|
|
14
|
+
from llmflowstack.callbacks.stop_on_token import StopOnToken
|
|
15
|
+
from llmflowstack.schemas.params import GenerationParams, TrainParams
|
|
16
|
+
from llmflowstack.utils.exceptions import MissingEssentialProp
|
|
17
|
+
from llmflowstack.utils.generation_utils import create_generation_params
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LLaMA4Input(TypedDict):
|
|
21
|
+
input_text: str
|
|
22
|
+
expected_answer: str | None
|
|
23
|
+
system_message: str | None
|
|
24
|
+
image_paths: list[str] | None
|
|
25
|
+
|
|
26
|
+
class LLaMA4(BaseModel):
    """Text-only wrapper around ``Llama4ForCausalLM``.

    Supports domain-adaptive pre-training (``dapt``), blocking generation
    (``generate``) and token streaming (``generate_stream``). Quantization is
    not offered for this model: ``None`` is always forwarded to the base class.
    """

    model: Llama4ForCausalLM | None = None
    question_fields = ["input_text", "system_message"]
    answer_fields = ["expected_answer"]

    # Upper bound applied when the caller does not set ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 32768

    def __init__(
        self,
        checkpoint: str | None = None,
        seed: int | None = None,
        log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
    ) -> None:
        """Initialise the wrapper; quantization is always passed as ``None``."""
        super().__init__(
            checkpoint=checkpoint,
            quantization=None,
            seed=seed,
            log_level=log_level
        )

    def _set_generation_stopping_tokens(
        self,
        tokens: list[int]
    ) -> None:
        """Store the stop-token ids: ``tokens`` plus the ids of ``<|eot|>``.

        Logs a warning and leaves the stop list untouched when no tokenizer
        is loaded.
        """
        if not self.tokenizer:
            self._log("Could not set stop tokens - generation may not work...", "WARNING")
            return None

        # add_special_tokens=False keeps BOS/EOS markers the tokenizer would
        # otherwise add around the text out of the stop list; only the ids of
        # the "<|eot|>" text itself are wanted.
        particular_tokens = self.tokenizer.encode("<|eot|>", add_special_tokens=False)
        self.stop_token_ids = tokens + particular_tokens

    def _load_model(
        self,
        checkpoint: str,
        quantization: None = None
    ) -> None:
        """Load ``Llama4ForCausalLM`` from ``checkpoint``.

        ``quantization`` is accepted for signature compatibility and ignored.
        """
        self.model = Llama4ForCausalLM.from_pretrained(
            checkpoint,
            dtype="auto",
            device_map="auto",
            attn_implementation="eager"
        )

    def load_checkpoint(
        self,
        checkpoint: str,
        quantization: None = None
    ) -> None:
        """Delegate to the base-class loader with a narrowed ``quantization`` type."""
        return super().load_checkpoint(checkpoint, quantization)

    def _build_input(
        self,
        data: LLaMA4Input
    ) -> str:
        """Render ``data`` as a Llama 4 chat prompt string.

        The prompt always ends with an open assistant header; when
        ``expected_answer`` is set it is appended followed by ``<|eot|>`` so
        training samples end on the same token generation stops on.
        ``image_paths`` is not used here.

        Raises:
            MissingEssentialProp: if no tokenizer is loaded.
        """
        if not self.tokenizer:
            raise MissingEssentialProp("Could not find tokenizer.")

        system_message = data.get("system_message") or ""
        expected_answer = data.get("expected_answer") or ""
        input_text = data["input_text"]

        # The previous template used Gemma's <start_of_turn>/<end_of_turn>
        # markers, which do not match the "<|eot|>" stop token registered in
        # _set_generation_stopping_tokens. Use the Llama 4 turn layout.
        # NOTE(review): "<|begin_of_text|>" is intentionally not emitted here,
        # assuming the tokenizer adds BOS during _tokenize - confirm.
        system_turn = ""
        if system_message:
            system_turn = (
                f"<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>"
            )

        answer = f"{expected_answer}<|eot|>" if expected_answer else ""

        return (
            f"{system_turn}"
            f"<|header_start|>user<|header_end|>\n\n{input_text}<|eot|>"
            f"<|header_start|>assistant<|header_end|>\n\n"
            f"{answer}"
        )

    def build_input(
        self,
        input_text: str,
        system_message: str | None = None,
        expected_answer: str | None = None,
        image_paths: list[str] | None = None
    ) -> LLaMA4Input:
        """Pack the arguments into a ``LLaMA4Input`` dict.

        Raises:
            MissingEssentialProp: if no tokenizer is loaded.
        """
        if not self.tokenizer:
            raise MissingEssentialProp("Could not find tokenizer.")

        return {
            "input_text": input_text,
            "system_message": system_message,
            "expected_answer": expected_answer,
            "image_paths": image_paths
        }

    def dapt(
        self,
        train_dataset: list,
        params: TrainParams | None = None,
        eval_dataset: list | None = None,
        save_at_end: bool = True,
        save_path: str | None = None
    ) -> None:
        """Run domain-adaptive pre-training with a causal-LM objective.

        Returns early (after logging a warning) when the model or tokenizer
        is missing, or when the model is quantized. When ``save_at_end`` is
        true the checkpoint is written to ``save_path`` after training.
        """
        if not self.model:
            self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
            return None
        if not self.tokenizer:
            self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
            return None

        self._log("Starting DAPT")

        if self.model_is_quantized:
            self._log("Cannot DAPT a quantized model.", "WARNING")
            return None

        if params is None:
            params = TrainParams()

        training_arguments = TrainingArguments(
            num_train_epochs=params.epochs,
            learning_rate=params.lr,
            gradient_accumulation_steps=params.gradient_accumulation,
            warmup_ratio=params.warmup_ratio,
            lr_scheduler_type="cosine_with_min_lr",
            lr_scheduler_kwargs={"min_lr_rate": 0.1},
            output_dir=None,
            # Checkpointing is handled explicitly below via save_checkpoint.
            save_strategy="no",
            logging_steps=params.logging_steps
        )

        if self.seed is not None:
            training_arguments.seed = self.seed

        processed_train_dataset = self._promptfy_dataset_for_dapt(train_dataset)
        tokenized_train_dataset = self._tokenize_dataset_for_dapt(processed_train_dataset)

        tokenized_eval_dataset = None
        if eval_dataset:
            processed_eval_dataset = self._promptfy_dataset_for_dapt(eval_dataset)
            tokenized_eval_dataset = self._tokenize_dataset_for_dapt(processed_eval_dataset)

        log_callback = LogCollectorCallback()

        trainer = Trainer(
            model=self.model,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            args=training_arguments,
            callbacks=[log_callback],
            # mlm=False selects the causal (next-token) objective.
            data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        )

        trainer.train()

        if save_at_end:
            if save_path:
                self.save_checkpoint(
                    path=save_path
                )
            else:
                # Previously this case was skipped silently, so a finished
                # run could be lost without any hint.
                self._log("save_at_end is set but no save_path was given - model not saved.", "WARNING")

        self._log("Finished DAPT")

    def fine_tune(
        self,
        train_dataset: list,
        params: TrainParams | None = None,
        eval_dataset: list | None = None,
        save_at_end: bool = True,
        save_path: str | None = None
    ) -> None:
        """Forward to ``dapt``; this class has no separate fine-tuning path."""
        self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
        return self.dapt(
            train_dataset=train_dataset,
            params=params,
            eval_dataset=eval_dataset,
            save_at_end=save_at_end,
            save_path=save_path
        )

    def _prepare_generation_inputs(
        self,
        input: LLaMA4Input | str,
        params: GenerationParams | None
    ):
        """Apply generation params and return ``_tokenize``'s result for ``input``.

        Shared by ``generate`` and ``generate_stream``. Callers must have
        verified that ``self.model`` is loaded. A missing ``max_new_tokens``
        is filled in on the caller's ``params`` object, as before.
        """
        if params is None:
            params = GenerationParams(max_new_tokens=self.DEFAULT_MAX_NEW_TOKENS)
        elif params.max_new_tokens is None:
            params.max_new_tokens = self.DEFAULT_MAX_NEW_TOKENS

        self.model.generation_config = create_generation_params(params)

        # A bare string is first wrapped into a LLaMA4Input payload.
        if isinstance(input, str):
            input = self.build_input(
                input_text=input
            )
        model_input = self._build_input(
            data=input
        )

        return self._tokenize(model_input)

    def generate(
        self,
        input: LLaMA4Input | str,
        params: GenerationParams | None = None
    ) -> str | None:
        """Generate a full response for ``input``.

        Returns the decoded continuation (prompt tokens stripped, special
        tokens skipped), or ``None`` when the model or tokenizer is missing.
        """
        if self.model is None or self.tokenizer is None:
            self._log("Model or Tokenizer missing", "WARNING")
            return None

        self._log("Processing received input...'")

        input_ids, attention_mask = self._prepare_generation_inputs(input, params)

        self.model.eval()
        self.model.gradient_checkpointing_disable()

        start = time()

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                # Stopping is delegated entirely to the StopOnToken criterion.
                eos_token_id=None,
                stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
            )

        total_time = time() - start

        self._log(f"Response generated in {total_time:.4f} seconds")

        # Drop the prompt tokens; keep only the newly generated ones.
        response = outputs[0][input_ids.shape[1]:]

        return self.tokenizer.decode(response, skip_special_tokens=True)

    def generate_stream(
        self,
        input: LLaMA4Input | str,
        params: GenerationParams | None = None
    ) -> Iterator[str]:
        """Yield the response for ``input`` piece by piece as it is generated.

        Yields nothing (after logging a warning) when the model or tokenizer
        is missing. Generation runs on a worker thread while this generator
        drains the streamer.
        """
        if self.model is None or self.tokenizer is None:
            self._log("Model or Tokenizer missing", "WARNING")
            # The ``yield`` further down already makes this a generator, so a
            # bare return produces an empty iterator.
            return

        input_ids, attention_mask = self._prepare_generation_inputs(input, params)

        # Match generate(): leave training mode before sampling.
        self.model.eval()

        streamer = TextIteratorStreamer(
            cast(AutoTokenizer, self.tokenizer),
            skip_prompt=True,
            skip_special_tokens=True
        )

        generate_fn = partial(
            self.model.generate,
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            eos_token_id=None,
            streamer=streamer,
            stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
        )

        thread = threading.Thread(target=generate_fn)
        thread.start()

        for new_text in streamer:
            yield new_text

        # The streamer is exhausted, so generation has signalled its end;
        # reap the worker instead of leaving it dangling.
        thread.join()
|
llmflowstack-1.0.2/llmflowstack/models/Gemma.py → llmflowstack-1.1.1/llmflowstack/models/MedGemma.py
RENAMED
|
@@ -17,17 +17,31 @@ from llmflowstack.utils.exceptions import MissingEssentialProp
|
|
|
17
17
|
from llmflowstack.utils.generation_utils import create_generation_params
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class
|
|
20
|
+
class MedGemmaInput(TypedDict):
|
|
21
21
|
input_text: str
|
|
22
22
|
expected_answer: str | None
|
|
23
23
|
system_message: str | None
|
|
24
24
|
|
|
25
|
-
class
|
|
25
|
+
class MedGemma(BaseModel):
|
|
26
26
|
model: Gemma3ForCausalLM | None = None
|
|
27
27
|
can_think = False
|
|
28
28
|
question_fields = ["input_text", "system_message"]
|
|
29
29
|
answer_fields = ["expected_answer"]
|
|
30
30
|
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
checkpoint: str | None = None,
|
|
34
|
+
quantization: Literal["4bit"] | None = None,
|
|
35
|
+
seed: int | None = None,
|
|
36
|
+
log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
|
|
37
|
+
) -> None:
|
|
38
|
+
return super().__init__(
|
|
39
|
+
checkpoint=checkpoint,
|
|
40
|
+
quantization=quantization,
|
|
41
|
+
seed=seed,
|
|
42
|
+
log_level=log_level
|
|
43
|
+
)
|
|
44
|
+
|
|
31
45
|
def _set_generation_stopping_tokens(
|
|
32
46
|
self,
|
|
33
47
|
tokens: list[int]
|
|
@@ -41,17 +55,13 @@ class Gemma(BaseModel):
|
|
|
41
55
|
def _load_model(
|
|
42
56
|
self,
|
|
43
57
|
checkpoint: str,
|
|
44
|
-
quantization: Literal["
|
|
58
|
+
quantization: Literal["4bit"] | None = None
|
|
45
59
|
) -> None:
|
|
46
60
|
quantization_config = None
|
|
47
61
|
if quantization == "4bit":
|
|
48
62
|
quantization_config = BitsAndBytesConfig(
|
|
49
63
|
load_in_4bit=True
|
|
50
64
|
)
|
|
51
|
-
if quantization == "8bit":
|
|
52
|
-
quantization_config = BitsAndBytesConfig(
|
|
53
|
-
load_in_8bit=True
|
|
54
|
-
)
|
|
55
65
|
|
|
56
66
|
self.model = Gemma3ForCausalLM.from_pretrained(
|
|
57
67
|
checkpoint,
|
|
@@ -60,16 +70,22 @@ class Gemma(BaseModel):
|
|
|
60
70
|
device_map="auto",
|
|
61
71
|
attn_implementation="eager"
|
|
62
72
|
)
|
|
73
|
+
|
|
74
|
+
def load_checkpoint(
|
|
75
|
+
self,
|
|
76
|
+
checkpoint: str,
|
|
77
|
+
quantization: Literal["4bit"] | None = None
|
|
78
|
+
) -> None:
|
|
79
|
+
return super().load_checkpoint(checkpoint, quantization)
|
|
63
80
|
|
|
64
81
|
def _build_input(
|
|
65
82
|
self,
|
|
66
|
-
|
|
67
|
-
expected_answer: str | None = None,
|
|
68
|
-
system_message: str | None = None
|
|
83
|
+
data: MedGemmaInput
|
|
69
84
|
) -> str:
|
|
70
85
|
if not self.tokenizer:
|
|
71
86
|
raise MissingEssentialProp("Could not find tokenizer.")
|
|
72
87
|
|
|
88
|
+
system_message = data.get("system_message", "")
|
|
73
89
|
if not system_message:
|
|
74
90
|
system_message = ""
|
|
75
91
|
if self.can_think:
|
|
@@ -78,11 +94,12 @@ class Gemma(BaseModel):
|
|
|
78
94
|
if system_message:
|
|
79
95
|
system_message = f"{system_message}\n"
|
|
80
96
|
|
|
97
|
+
expected_answer = data.get("expected_answer")
|
|
81
98
|
answer = f"{expected_answer}<end_of_turn>" if expected_answer else ""
|
|
82
99
|
|
|
83
|
-
return
|
|
100
|
+
return (
|
|
84
101
|
f"<start_of_turn>user"
|
|
85
|
-
f"{system_message}\n{input_text}<end_of_turn>\n"
|
|
102
|
+
f"{system_message}\n{data["input_text"]}<end_of_turn>\n"
|
|
86
103
|
f"<start_of_turn>model\n"
|
|
87
104
|
f"{answer}"
|
|
88
105
|
)
|
|
@@ -92,7 +109,7 @@ class Gemma(BaseModel):
|
|
|
92
109
|
input_text: str,
|
|
93
110
|
expected_answer: str | None = None,
|
|
94
111
|
system_message: str | None = None
|
|
95
|
-
) ->
|
|
112
|
+
) -> MedGemmaInput:
|
|
96
113
|
if not self.tokenizer:
|
|
97
114
|
raise MissingEssentialProp("Could not find tokenizer.")
|
|
98
115
|
|
|
@@ -107,7 +124,7 @@ class Gemma(BaseModel):
|
|
|
107
124
|
|
|
108
125
|
def generate(
|
|
109
126
|
self,
|
|
110
|
-
input:
|
|
127
|
+
input: MedGemmaInput | str,
|
|
111
128
|
params: GenerationParams | None = None,
|
|
112
129
|
) -> str | None:
|
|
113
130
|
if self.model is None or self.tokenizer is None:
|
|
@@ -126,17 +143,18 @@ class Gemma(BaseModel):
|
|
|
126
143
|
|
|
127
144
|
model_input = None
|
|
128
145
|
if isinstance(input, str):
|
|
129
|
-
model_input = self.
|
|
146
|
+
model_input = self.build_input(
|
|
130
147
|
input_text=input
|
|
131
148
|
)
|
|
149
|
+
model_input = self._build_input(
|
|
150
|
+
data=model_input
|
|
151
|
+
)
|
|
132
152
|
else:
|
|
133
153
|
model_input = self._build_input(
|
|
134
|
-
|
|
135
|
-
system_message=input["system_message"]
|
|
154
|
+
data=input
|
|
136
155
|
)
|
|
137
156
|
|
|
138
157
|
tokenized_input = self._tokenize(model_input)
|
|
139
|
-
|
|
140
158
|
input_ids, attention_mask = tokenized_input
|
|
141
159
|
|
|
142
160
|
self.model.eval()
|
|
@@ -174,7 +192,7 @@ class Gemma(BaseModel):
|
|
|
174
192
|
|
|
175
193
|
def generate_stream(
|
|
176
194
|
self,
|
|
177
|
-
input:
|
|
195
|
+
input: MedGemmaInput | str,
|
|
178
196
|
params: GenerationParams | None = None
|
|
179
197
|
) -> Iterator[str]:
|
|
180
198
|
if self.model is None or self.tokenizer is None:
|
|
@@ -191,14 +209,17 @@ class Gemma(BaseModel):
|
|
|
191
209
|
generation_params = create_generation_params(params)
|
|
192
210
|
self.model.generation_config = generation_params
|
|
193
211
|
|
|
212
|
+
model_input = None
|
|
194
213
|
if isinstance(input, str):
|
|
195
|
-
model_input = self.
|
|
214
|
+
model_input = self.build_input(
|
|
196
215
|
input_text=input
|
|
197
216
|
)
|
|
217
|
+
model_input = self._build_input(
|
|
218
|
+
data=model_input
|
|
219
|
+
)
|
|
198
220
|
else:
|
|
199
221
|
model_input = self._build_input(
|
|
200
|
-
|
|
201
|
-
system_message=input.get("system_message")
|
|
222
|
+
data=input
|
|
202
223
|
)
|
|
203
224
|
|
|
204
225
|
tokenized_input = self._tokenize(model_input)
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import uuid
|
|
2
2
|
|
|
3
|
+
import chromadb
|
|
4
|
+
import chromadb.config
|
|
3
5
|
from langchain_chroma import Chroma
|
|
4
6
|
from langchain_core.documents import Document
|
|
5
7
|
from langchain_core.embeddings import Embeddings
|
|
@@ -38,10 +40,15 @@ class RAGPipeline:
|
|
|
38
40
|
|
|
39
41
|
self.encoder = SentenceTransformer(checkpoint, trust_remote_code=True)
|
|
40
42
|
|
|
43
|
+
client_settings = chromadb.config.Settings(
|
|
44
|
+
anonymized_telemetry=False
|
|
45
|
+
)
|
|
46
|
+
|
|
41
47
|
self.vector_store = Chroma(
|
|
42
48
|
collection_name=collection_name,
|
|
43
49
|
embedding_function=EncoderWrapper(self.encoder),
|
|
44
|
-
persist_directory=persist_directory
|
|
50
|
+
persist_directory=persist_directory,
|
|
51
|
+
client_settings=client_settings
|
|
45
52
|
)
|
|
46
53
|
|
|
47
54
|
self.splitter = RecursiveCharacterTextSplitter(
|
|
@@ -4,11 +4,11 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "llmflowstack"
|
|
7
|
-
version = "1.
|
|
7
|
+
version = "1.1.1"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name = "Gustavo Henrique Ferreira Cruz", email = "gustavohferreiracruz@gmail.com" }
|
|
10
10
|
]
|
|
11
|
-
description = "LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components."
|
|
11
|
+
description = "LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components."
|
|
12
12
|
readme = "README.md"
|
|
13
13
|
requires-python = ">=3.12"
|
|
14
14
|
license = {text = "MIT"}
|
|
@@ -22,6 +22,7 @@ dependencies = [
|
|
|
22
22
|
"datasets",
|
|
23
23
|
"evaluate",
|
|
24
24
|
"huggingface-hub",
|
|
25
|
+
"kernels",
|
|
25
26
|
"langchain-chroma",
|
|
26
27
|
"langchain_community",
|
|
27
28
|
"nltk",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|