llmflowstack 1.0.2__tar.gz → 1.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/PKG-INFO +19 -5
  2. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/README.md +16 -3
  3. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/__init__.py +7 -3
  4. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/base/base.py +15 -8
  5. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/models/GPT_OSS.py +42 -23
  6. llmflowstack-1.1.1/llmflowstack/models/Gemma.py +319 -0
  7. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/models/LLaMA3.py +39 -14
  8. llmflowstack-1.1.1/llmflowstack/models/LLaMA4.py +317 -0
  9. llmflowstack-1.0.2/llmflowstack/models/Gemma.py → llmflowstack-1.1.1/llmflowstack/models/MedGemma.py +43 -22
  10. llmflowstack-1.1.1/llmflowstack/models/__init__.py +13 -0
  11. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/rag/pipeline.py +8 -1
  12. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/pyproject.toml +3 -2
  13. llmflowstack-1.0.2/llmflowstack/models/__init__.py +0 -9
  14. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/.github/workflows/python-publish.yml +0 -0
  15. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/.gitignore +0 -0
  16. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/LICENSE +0 -0
  17. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/base/__init__.py +0 -0
  18. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/callbacks/__init__.py +0 -0
  19. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/callbacks/log_collector.py +0 -0
  20. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/callbacks/stop_on_token.py +0 -0
  21. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/rag/__iinit__.py +0 -0
  22. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/schemas/__init__.py +0 -0
  23. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/schemas/params.py +0 -0
  24. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/__init__.py +0 -0
  25. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/evaluation_methods.py +0 -0
  26. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/exceptions.py +0 -0
  27. {llmflowstack-1.0.2 → llmflowstack-1.1.1}/llmflowstack/utils/generation_utils.py +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: llmflowstack
3
- Version: 1.0.2
4
- Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
3
+ Version: 1.1.1
4
+ Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
5
5
  Author-email: Gustavo Henrique Ferreira Cruz <gustavohferreiracruz@gmail.com>
6
6
  License: MIT
7
7
  License-File: LICENSE
@@ -14,6 +14,7 @@ Requires-Dist: colorama
14
14
  Requires-Dist: datasets
15
15
  Requires-Dist: evaluate
16
16
  Requires-Dist: huggingface-hub
17
+ Requires-Dist: kernels
17
18
  Requires-Dist: langchain-chroma
18
19
  Requires-Dist: langchain-community
19
20
  Requires-Dist: nltk
@@ -56,18 +57,31 @@ This framework is designed to provide flexibility when working with different op
56
57
 
57
58
  - [`GPT-OSS 20B`](https://huggingface.co/openai/gpt-oss-20b)
58
59
  - [`GPT-OSS 120B`](https://huggingface.co/openai/gpt-oss-120b)
60
+ > Fine-Tuning, DAPT and Inference Available
59
61
 
60
- - **LLaMA**
62
+ - **LLaMA 3**
61
63
 
62
64
  - [`LLaMA 3.1 8B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
63
65
  - [`LLaMA 3.1 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct)
64
66
  - [`LLaMA 3.3 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
65
67
  - [`LLaMA 3.1 405B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct)
68
+ > Fine-Tuning, DAPT and Inference Available
69
+
70
+ - **LLaMA 4**
71
+
72
+ - [`LLaMA 4 Scout - Instruct`](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
73
+ > DAPT and Inference Available
66
74
 
67
75
  - **Gemma**
68
- - [`MedGemma 27B Text - It`](https://huggingface.co/google/medgemma-27b-text-it)
69
76
 
70
- > Compatibility includes both inference and training (Domain-Adaptive Pre-Training — DAPT — and Supervised Fine-Tuning)
77
+ - [`Gemma 3 27B - Instruct`](https://huggingface.co/google/gemma-3-27b-it)
78
+ > DAPT and Inference Available
79
+
80
+ - **MedGemma**
81
+ - [`MedGemma 27B Text - Instruct`](https://huggingface.co/google/medgemma-27b-text-it)
82
+ > Fine-Tuning, DAPT and Inference Available
83
+
84
+ > Other architectures based on the models listed above **may** also work correctly.
71
85
 
72
86
  ---
73
87
 
@@ -20,18 +20,31 @@ This framework is designed to provide flexibility when working with different op
20
20
 
21
21
  - [`GPT-OSS 20B`](https://huggingface.co/openai/gpt-oss-20b)
22
22
  - [`GPT-OSS 120B`](https://huggingface.co/openai/gpt-oss-120b)
23
+ > Fine-Tuning, DAPT and Inference Available
23
24
 
24
- - **LLaMA**
25
+ - **LLaMA 3**
25
26
 
26
27
  - [`LLaMA 3.1 8B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
27
28
  - [`LLaMA 3.1 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct)
28
29
  - [`LLaMA 3.3 70B - Instruct`](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)
29
30
  - [`LLaMA 3.1 405B - Instruct`](https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct)
31
+ > Fine-Tuning, DAPT and Inference Available
32
+
33
+ - **LLaMA 4**
34
+
35
+ - [`LLaMA 4 Scout - Instruct`](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
36
+ > DAPT and Inference Available
30
37
 
31
38
  - **Gemma**
32
- - [`MedGemma 27B Text - It`](https://huggingface.co/google/medgemma-27b-text-it)
33
39
 
34
- > Compatibility includes both inference and training (Domain-Adaptive Pre-Training — DAPT — and Supervised Fine-Tuning)
40
+ - [`Gemma 3 27B - Instruct`](https://huggingface.co/google/gemma-3-27b-it)
41
+ > DAPT and Inference Available
42
+
43
+ - **MedGemma**
44
+ - [`MedGemma 27B Text - Instruct`](https://huggingface.co/google/medgemma-27b-text-it)
45
+ > Fine-Tuning, DAPT and Inference Available
46
+
47
+ > Other architectures based on the models listed above **may** also work correctly.
35
48
 
36
49
  ---
37
50
 
@@ -1,15 +1,19 @@
1
- from .models.Gemma import Gemma
1
+ from .models.Gemma import Gemma3
2
2
  from .models.GPT_OSS import GPT_OSS
3
3
  from .models.LLaMA3 import LLaMA3
4
+ from .models.LLaMA4 import LLaMA4
5
+ from .models.MedGemma import MedGemma
4
6
  from .rag.pipeline import RAGPipeline
5
7
  from .schemas.params import (GenerationBeamsParams, GenerationParams,
6
8
  GenerationSampleParams, TrainParams)
7
9
  from .utils.evaluation_methods import text_evaluation
8
10
 
9
11
  __all__ = [
10
- "Gemma",
11
- "LLaMA3",
12
+ "Gemma3",
12
13
  "GPT_OSS",
14
+ "LLaMA3",
15
+ "LLaMA4",
16
+ "MedGemma",
13
17
  "RAGPipeline",
14
18
  "GenerationBeamsParams",
15
19
  "GenerationParams",
@@ -1,3 +1,4 @@
1
+ import gc
1
2
  import json
2
3
  import logging
3
4
  import os
@@ -35,7 +36,7 @@ class BaseModel(ABC):
35
36
  def __init__(
36
37
  self,
37
38
  checkpoint: str | None = None,
38
- quantization: Literal["8bit", "4bit"] | bool | None = None,
39
+ quantization: Literal["4bit", "8bit"] | bool | None = None,
39
40
  seed: int | None = None,
40
41
  log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
41
42
  ) -> None:
@@ -64,7 +65,8 @@ class BaseModel(ABC):
64
65
  def _load_model(
65
66
  self,
66
67
  checkpoint: str,
67
- quantization: Literal["8bit", "4bit"] | bool | None = None
68
+ *args: Any,
69
+ **kwargs: Any
68
70
  ) -> None:
69
71
  pass
70
72
 
@@ -79,7 +81,7 @@ class BaseModel(ABC):
79
81
  def load_checkpoint(
80
82
  self,
81
83
  checkpoint: str,
82
- quantization: Literal["8bit", "4bit"] | bool | None = None
84
+ quantization: Any
83
85
  ) -> None:
84
86
  if self.model:
85
87
  self._log("A model is already loaded. Attempting to reset it.", "WARNING")
@@ -223,7 +225,7 @@ class BaseModel(ABC):
223
225
  self,
224
226
  *args: Any,
225
227
  **kwargs: Any
226
- ) -> str:
228
+ ) -> str | BatchEncoding:
227
229
  pass
228
230
 
229
231
  def _tokenize(
@@ -282,7 +284,7 @@ class BaseModel(ABC):
282
284
  output = []
283
285
  for data in dataset:
284
286
  complete_input = self._build_input(
285
- **{field: data.get(field) for field in self.question_fields + self.answer_fields}
287
+ data
286
288
  )
287
289
  output.append(complete_input)
288
290
 
@@ -403,13 +405,16 @@ class BaseModel(ABC):
403
405
  def _build_input_for_fine_tune(
404
406
  self,
405
407
  input: dict
406
- ) -> dict[Literal["partial", "complete"], str]:
408
+ ) -> dict[Literal["partial", "complete"], str | BatchEncoding]:
407
409
  if not self.tokenizer:
408
410
  raise MissingEssentialProp("Could not find tokenizer.")
409
411
 
410
- partial = self._build_input(**{k: input[k] for k in self.question_fields if k in input})
412
+ partial = self._build_input({
413
+ **input,
414
+ "expected_answer": None
415
+ })
411
416
 
412
- complete = self._build_input(**{k: input[k] for k in self.question_fields + self.answer_fields if k in input})
417
+ complete = self._build_input(input)
413
418
 
414
419
  return {
415
420
  "partial": partial,
@@ -508,6 +513,8 @@ class BaseModel(ABC):
508
513
  try:
509
514
  self._log("Trying to reset model...")
510
515
  del self.model
516
+ gc.collect()
517
+ torch.cuda.empty_cache()
511
518
  self.model = None
512
519
  self.model_is_quantized = None
513
520
  self.process_id = None
@@ -2,7 +2,7 @@ import textwrap
2
2
  import threading
3
3
  from functools import partial
4
4
  from time import time
5
- from typing import Any, Generator, Iterator, Literal, TypedDict, cast
5
+ from typing import Iterator, Literal, TypedDict, cast
6
6
 
7
7
  import torch
8
8
  from openai_harmony import HarmonyEncodingName, load_harmony_encoding
@@ -32,6 +32,20 @@ class GPT_OSS(BaseModel):
32
32
  question_fields = ["input_text", "developer_message", "system_message"]
33
33
  answer_fields = ["expected_answer", "reasoning_message"]
34
34
 
35
+ def __init__(
36
+ self,
37
+ checkpoint: str | None = None,
38
+ quantization: bool | None = None,
39
+ seed: int | None = None,
40
+ log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
41
+ ) -> None:
42
+ return super().__init__(
43
+ checkpoint=checkpoint,
44
+ quantization=quantization,
45
+ seed=seed,
46
+ log_level=log_level
47
+ )
48
+
35
49
  def _set_generation_stopping_tokens(
36
50
  self,
37
51
  tokens: list[int]
@@ -46,10 +60,9 @@ class GPT_OSS(BaseModel):
46
60
  def _load_model(
47
61
  self,
48
62
  checkpoint: str,
49
- quantization: Literal["8bit", "4bit"] | bool | None = False
63
+ quantization: bool | None = False
50
64
  ) -> None:
51
65
  if quantization:
52
- self.model_is_quantized = True
53
66
  quantization_config = Mxfp4Config(dequantize=False)
54
67
  else:
55
68
  quantization_config = Mxfp4Config(dequantize=True)
@@ -70,37 +83,43 @@ class GPT_OSS(BaseModel):
70
83
  device_map="auto",
71
84
  attn_implementation="eager"
72
85
  )
86
+
87
+ def load_checkpoint(
88
+ self,
89
+ checkpoint: str,
90
+ quantization: bool | None = None
91
+ ) -> None:
92
+ return super().load_checkpoint(checkpoint, quantization)
73
93
 
74
94
  def _build_input(
75
95
  self,
76
- input_text: str,
77
- expected_answer: str | None = None,
78
- system_message: str | None = None,
79
- reasoning_level: Literal["Low", "Medium", "High"] | None = None,
80
- reasoning_message: str | None = None,
81
- developer_message: str | None = None
96
+ data: GPTOSSInput
82
97
  ) -> str:
83
98
  if not self.tokenizer:
84
99
  raise MissingEssentialProp("Could not find tokenizer.")
85
100
 
86
- reasoning = reasoning_level
101
+ reasoning = data.get("reasoning_level")
87
102
  if reasoning is None:
88
103
  reasoning = self.reasoning_level
89
104
 
90
- system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\nReasoning: {reasoning}\n\n{system_message or ""}# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
105
+ system_message = data.get("system_message", "")
106
+ system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\nReasoning: {reasoning}\n\n{system_message}# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
91
107
 
92
108
  developer_text = ""
109
+ developer_message = data.get("developer_message", "")
93
110
  if developer_message:
94
- developer_text = f"<|start|>developer<|message|># Instructions\n\n{developer_message or ""}<|end|>"
111
+ developer_text = f"<|start|>developer<|message|># Instructions\n\n{developer_message}<|end|>"
95
112
 
96
113
  assistant_text = ""
114
+ reasoning_message = data.get("reasoning_message", "")
97
115
  if reasoning_message:
98
116
  assistant_text += f"<|start|>assistant<|channel|>analysis<|message|>{reasoning_message}<|end|>"
99
117
 
118
+ expected_answer = data.get("expected_answer", "")
100
119
  if expected_answer:
101
120
  assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"
102
121
 
103
- return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{input_text}<|end|>{assistant_text}""")
122
+ return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{data["input_text"]}<|end|>{assistant_text}""")
104
123
 
105
124
  def build_input(
106
125
  self,
@@ -150,15 +169,15 @@ class GPT_OSS(BaseModel):
150
169
 
151
170
  model_input = None
152
171
  if isinstance(input, str):
153
- model_input = self._build_input(
172
+ model_input = self.build_input(
154
173
  input_text=input
155
174
  )
175
+ model_input = self._build_input(
176
+ data=model_input
177
+ )
156
178
  else:
157
179
  model_input = self._build_input(
158
- input_text=input["input_text"],
159
- developer_message=input.get("developer_message", None),
160
- system_message=input.get("system_message", None),
161
- reasoning_level=input.get("reasoning_level", None)
180
+ data=input
162
181
  )
163
182
 
164
183
  tokenized_input = self._tokenize(model_input)
@@ -217,15 +236,15 @@ class GPT_OSS(BaseModel):
217
236
  self.model.generation_config = generation_params
218
237
 
219
238
  if isinstance(input, str):
220
- model_input = self._build_input(
239
+ model_input = self.build_input(
221
240
  input_text=input
222
241
  )
242
+ model_input = self._build_input(
243
+ data=model_input
244
+ )
223
245
  else:
224
246
  model_input = self._build_input(
225
- input_text=input["input_text"],
226
- developer_message=input.get("developer_message"),
227
- system_message=input.get("system_message"),
228
- reasoning_level=input.get("reasoning_level")
247
+ data=input
229
248
  )
230
249
 
231
250
  tokenized_input = self._tokenize(model_input)
@@ -0,0 +1,319 @@
1
+ import threading
2
+ from functools import partial
3
+ from time import time
4
+ from typing import Iterator, Literal, TypedDict, cast
5
+
6
+ import torch
7
+ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
8
+ StoppingCriteriaList, TextIteratorStreamer, Trainer,
9
+ TrainingArguments)
10
+ from transformers.models.gemma3 import Gemma3ForCausalLM
11
+ from transformers.utils.quantization_config import BitsAndBytesConfig
12
+
13
+ from llmflowstack.base.base import BaseModel
14
+ from llmflowstack.callbacks.log_collector import LogCollectorCallback
15
+ from llmflowstack.callbacks.stop_on_token import StopOnToken
16
+ from llmflowstack.schemas.params import GenerationParams, TrainParams
17
+ from llmflowstack.utils.exceptions import MissingEssentialProp
18
+ from llmflowstack.utils.generation_utils import create_generation_params
19
+
20
+
21
+ class Gemma3Input(TypedDict):
22
+ input_text: str
23
+ expected_answer: str | None
24
+ system_message: str | None
25
+ image_paths: list[str] | None
26
+
27
+ class Gemma3(BaseModel):
28
+ model: Gemma3ForCausalLM | None = None
29
+ question_fields = ["input_text", "system_message"]
30
+ answer_fields = ["expected_answer"]
31
+
32
+ def __init__(
33
+ self,
34
+ checkpoint: str | None = None,
35
+ quantization: Literal["4bit"] | None = None,
36
+ seed: int | None = None,
37
+ log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
38
+ ) -> None:
39
+ return super().__init__(
40
+ checkpoint=checkpoint,
41
+ quantization=quantization,
42
+ seed=seed,
43
+ log_level=log_level
44
+ )
45
+
46
+ def _set_generation_stopping_tokens(
47
+ self,
48
+ tokens: list[int]
49
+ ) -> None:
50
+ if not self.tokenizer:
51
+ self._log("Could not set stop tokens - generation may not work...", "WARNING")
52
+ return None
53
+ particular_tokens = self.tokenizer.encode("<end_of_turn>")
54
+ self.stop_token_ids = tokens + particular_tokens
55
+
56
+ def _load_model(
57
+ self,
58
+ checkpoint: str,
59
+ quantization: Literal["4bit"] | None = None
60
+ ) -> None:
61
+ quantization_config = None
62
+ if quantization == "4bit":
63
+ quantization_config = BitsAndBytesConfig(
64
+ load_in_4bit=True
65
+ )
66
+
67
+ self.model = Gemma3ForCausalLM.from_pretrained(
68
+ checkpoint,
69
+ quantization_config=quantization_config,
70
+ dtype="auto",
71
+ device_map="auto",
72
+ attn_implementation="eager"
73
+ )
74
+
75
+ def load_checkpoint(
76
+ self,
77
+ checkpoint: str,
78
+ quantization: Literal['4bit'] | None = None
79
+ ) -> None:
80
+ return super().load_checkpoint(checkpoint, quantization)
81
+
82
+ def _build_input(
83
+ self,
84
+ data: Gemma3Input
85
+ ) -> str:
86
+ if not self.tokenizer:
87
+ raise MissingEssentialProp("Could not find tokenizer.")
88
+
89
+ system_message = data.get("system_message", "")
90
+ if not system_message:
91
+ system_message = ""
92
+
93
+ if system_message:
94
+ system_message = f"{system_message}\n"
95
+
96
+ expected_answer = data.get("expected_answer")
97
+ answer = f"{expected_answer}<end_of_turn>" if expected_answer else ""
98
+
99
+ return (
100
+ f"<start_of_turn>user"
101
+ f"{system_message}\n{data["input_text"]}<end_of_turn>\n"
102
+ f"<start_of_turn>model\n"
103
+ f"{answer}"
104
+ )
105
+
106
+ def build_input(
107
+ self,
108
+ input_text: str,
109
+ system_message: str | None = None,
110
+ expected_answer: str | None = None,
111
+ image_paths: list[str] | None = None
112
+ ) -> Gemma3Input:
113
+ if not self.tokenizer:
114
+ raise MissingEssentialProp("Could not find tokenizer.")
115
+
116
+ return {
117
+ "input_text": input_text,
118
+ "system_message": system_message,
119
+ "expected_answer": expected_answer,
120
+ "image_paths": image_paths
121
+ }
122
+
123
+ def dapt(
124
+ self,
125
+ train_dataset: list,
126
+ params: TrainParams | None = None,
127
+ eval_dataset: list | None = None,
128
+ save_at_end = True,
129
+ save_path: str | None = None
130
+ ) -> None:
131
+ if not self.model:
132
+ self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
133
+ return None
134
+ if not self.tokenizer:
135
+ self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
136
+ return None
137
+
138
+ self._log("Starting Training")
139
+
140
+ if self.model_is_quantized:
141
+ self._log("Cannot traub a quantized model.", "WARNING")
142
+ return None
143
+
144
+ if params is None:
145
+ params = TrainParams()
146
+
147
+ training_arguments = TrainingArguments(
148
+ num_train_epochs=params.epochs,
149
+ learning_rate=params.lr,
150
+ gradient_accumulation_steps=params.gradient_accumulation,
151
+ warmup_ratio=params.warmup_ratio,
152
+ lr_scheduler_type="cosine_with_min_lr",
153
+ lr_scheduler_kwargs={"min_lr_rate": 0.1},
154
+ output_dir=None,
155
+ save_strategy="no",
156
+ logging_steps=params.logging_steps
157
+ )
158
+
159
+ if self.seed is not None:
160
+ training_arguments.seed = self.seed
161
+
162
+ processed_train_dataset = self._promptfy_dataset_for_dapt(train_dataset)
163
+ tokenized_train_dataset = self._tokenize_dataset_for_dapt(processed_train_dataset)
164
+
165
+ tokenized_eval_dataset = None
166
+ if eval_dataset:
167
+ processed_eval_dataset = self._promptfy_dataset_for_dapt(eval_dataset)
168
+ tokenized_eval_dataset = self._tokenize_dataset_for_dapt(processed_eval_dataset)
169
+
170
+ log_callback = LogCollectorCallback()
171
+
172
+ trainer = Trainer(
173
+ model=self.model,
174
+ train_dataset=tokenized_train_dataset,
175
+ eval_dataset=tokenized_eval_dataset,
176
+ args=training_arguments,
177
+ callbacks=[log_callback],
178
+ data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
179
+ )
180
+
181
+ trainer.train()
182
+
183
+ if save_at_end and save_path:
184
+ self.save_checkpoint(
185
+ path=save_path
186
+ )
187
+
188
+ self._log("Finished Training")
189
+
190
+ def fine_tune(
191
+ self,
192
+ train_dataset: list,
193
+ params: TrainParams | None = None,
194
+ eval_dataset: list | None = None,
195
+ save_at_end = True,
196
+ save_path: str | None = None
197
+ ) -> None:
198
+ self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
199
+ return self.dapt(
200
+ train_dataset=train_dataset,
201
+ params=params,
202
+ eval_dataset=eval_dataset,
203
+ save_at_end=save_at_end,
204
+ save_path=save_path
205
+ )
206
+
207
+ def generate(
208
+ self,
209
+ input: Gemma3Input | str,
210
+ params: GenerationParams | None = None,
211
+ ) -> str | None:
212
+ if self.model is None or self.tokenizer is None:
213
+ self._log("Model or Tokenizer missing", "WARNING")
214
+ return None
215
+
216
+ self._log(f"Processing received input...'")
217
+
218
+ if params is None:
219
+ params = GenerationParams(max_new_tokens=32768)
220
+ elif params.max_new_tokens is None:
221
+ params.max_new_tokens = 32768
222
+
223
+ generation_params = create_generation_params(params)
224
+ self.model.generation_config = generation_params
225
+
226
+ model_input = None
227
+ if isinstance(input, str):
228
+ model_input = self.build_input(
229
+ input_text=input
230
+ )
231
+ model_input = self._build_input(
232
+ data=model_input
233
+ )
234
+ else:
235
+ model_input = self._build_input(
236
+ data=input
237
+ )
238
+
239
+ tokenized_input = self._tokenize(model_input)
240
+ input_ids, attention_mask = tokenized_input
241
+
242
+ self.model.eval()
243
+ self.model.gradient_checkpointing_disable()
244
+ start = time()
245
+
246
+ with torch.no_grad():
247
+ outputs = self.model.generate(
248
+ input_ids=input_ids,
249
+ attention_mask=attention_mask,
250
+ use_cache=True,
251
+ eos_token_id=None,
252
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
253
+ )
254
+
255
+ end = time()
256
+ total_time = end - start
257
+
258
+ self._log(f"Response generated in {total_time:.4f} seconds")
259
+
260
+ response = outputs[0][input_ids.shape[1]:]
261
+
262
+ return self.tokenizer.decode(response, skip_special_tokens=True)
263
+
264
+ def generate_stream(
265
+ self,
266
+ input: Gemma3Input | str,
267
+ params: GenerationParams | None = None
268
+ ) -> Iterator[str]:
269
+ if self.model is None or self.tokenizer is None:
270
+ self._log("Model or Tokenizer missing", "WARNING")
271
+ if False:
272
+ yield ""
273
+ return
274
+
275
+ if params is None:
276
+ params = GenerationParams(max_new_tokens=32768)
277
+ elif params.max_new_tokens is None:
278
+ params.max_new_tokens = 32768
279
+
280
+ generation_params = create_generation_params(params)
281
+ self.model.generation_config = generation_params
282
+
283
+ model_input = None
284
+ if isinstance(input, str):
285
+ model_input = self.build_input(
286
+ input_text=input
287
+ )
288
+ model_input = self._build_input(
289
+ data=model_input
290
+ )
291
+ else:
292
+ model_input = self._build_input(
293
+ data=input
294
+ )
295
+
296
+ tokenized_input = self._tokenize(model_input)
297
+ input_ids, attention_mask = tokenized_input
298
+
299
+ streamer = TextIteratorStreamer(
300
+ cast(AutoTokenizer, self.tokenizer),
301
+ skip_prompt=True,
302
+ skip_special_tokens=True
303
+ )
304
+
305
+ generate_fn = partial(
306
+ self.model.generate,
307
+ input_ids=input_ids,
308
+ attention_mask=attention_mask,
309
+ use_cache=True,
310
+ eos_token_id=None,
311
+ streamer=streamer,
312
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
313
+ )
314
+
315
+ thread = threading.Thread(target=generate_fn)
316
+ thread.start()
317
+
318
+ for new_text in streamer:
319
+ yield new_text
@@ -26,6 +26,20 @@ class LLaMA3(BaseModel):
26
26
  question_fields = ["input_text", "system_message"]
27
27
  answer_fields = ["expected_answer"]
28
28
 
29
+ def __init__(
30
+ self,
31
+ checkpoint: str | None = None,
32
+ quantization: Literal["4bit", "8bit"] | None = None,
33
+ seed: int | None = None,
34
+ log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
35
+ ) -> None:
36
+ return super().__init__(
37
+ checkpoint=checkpoint,
38
+ quantization=quantization,
39
+ seed=seed,
40
+ log_level=log_level
41
+ )
42
+
29
43
  def _set_generation_stopping_tokens(
30
44
  self,
31
45
  tokens: list[int]
@@ -39,19 +53,17 @@ class LLaMA3(BaseModel):
39
53
  def _load_model(
40
54
  self,
41
55
  checkpoint: str,
42
- quantization: Literal["8bit", "4bit"] | bool | None = None
56
+ quantization: Literal["4bit", "8bit"] | None = None
43
57
  ) -> None:
44
58
  quantization_config = None
45
59
  if quantization == "4bit":
46
60
  quantization_config = BitsAndBytesConfig(
47
61
  load_in_4bit=True
48
62
  )
49
- self.model_is_quantized = True
50
63
  if quantization == "8bit":
51
64
  quantization_config = BitsAndBytesConfig(
52
65
  load_in_8bit=True
53
66
  )
54
- self.model_is_quantized = True
55
67
 
56
68
  self.model = LlamaForCausalLM.from_pretrained(
57
69
  checkpoint,
@@ -60,21 +72,29 @@ class LLaMA3(BaseModel):
60
72
  device_map="auto",
61
73
  attn_implementation="eager"
62
74
  )
75
+
76
+ def load_checkpoint(
77
+ self,
78
+ checkpoint: str,
79
+ quantization: Literal['4bit', "8bit"] | None = None
80
+ ) -> None:
81
+ return super().load_checkpoint(checkpoint, quantization)
63
82
 
64
83
  def _build_input(
65
84
  self,
66
- input_text: str,
67
- expected_answer: str | None = None,
68
- system_message: str | None = None
85
+ data: LLaMA3Input
69
86
  ) -> str:
70
87
  if not self.tokenizer:
71
88
  raise MissingEssentialProp("Could not find tokenizer.")
72
89
 
90
+ expected_answer = data.get("expected_answer")
73
91
  answer = f"{expected_answer}{self.tokenizer.eos_token}" if expected_answer else ""
74
92
 
93
+ system_message = data.get("system_message", "")
94
+
75
95
  return textwrap.dedent(
76
- f"<|start_header_id|>system<|end_header_id|>{system_message or ""}\n"
77
- f"<|eot_id|><|start_header_id|>user<|end_header_id|>{input_text}\n"
96
+ f"<|start_header_id|>system<|end_header_id|>{system_message}\n"
97
+ f"<|eot_id|><|start_header_id|>user<|end_header_id|>{data["input_text"]}\n"
78
98
  f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>{answer}"
79
99
  )
80
100
 
@@ -120,13 +140,15 @@ class LLaMA3(BaseModel):
120
140
 
121
141
  model_input = None
122
142
  if isinstance(input, str):
123
- model_input = self._build_input(
143
+ model_input = self.build_input(
124
144
  input_text=input
125
145
  )
146
+ model_input = self._build_input(
147
+ data=model_input
148
+ )
126
149
  else:
127
150
  model_input = self._build_input(
128
- input_text=input["input_text"],
129
- system_message=input.get("system_message", "")
151
+ data=input
130
152
  )
131
153
 
132
154
  tokenized_input = self._tokenize(model_input)
@@ -175,14 +197,17 @@ class LLaMA3(BaseModel):
175
197
  generation_params = create_generation_params(params)
176
198
  self.model.generation_config = generation_params
177
199
 
200
+ model_input = None
178
201
  if isinstance(input, str):
179
- model_input = self._build_input(
202
+ model_input = self.build_input(
180
203
  input_text=input
181
204
  )
205
+ model_input = self._build_input(
206
+ data=model_input
207
+ )
182
208
  else:
183
209
  model_input = self._build_input(
184
- input_text=input["input_text"],
185
- system_message=input.get("system_message")
210
+ data=input
186
211
  )
187
212
 
188
213
  tokenized_input = self._tokenize(model_input)
@@ -0,0 +1,317 @@
1
+ import threading
2
+ from functools import partial
3
+ from time import time
4
+ from typing import Iterator, Literal, TypedDict, cast
5
+
6
+ import torch
7
+ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
8
+ StoppingCriteriaList, TextIteratorStreamer, Trainer,
9
+ TrainingArguments)
10
+ from transformers.models.llama4 import Llama4ForCausalLM
11
+
12
+ from llmflowstack.base.base import BaseModel
13
+ from llmflowstack.callbacks.log_collector import LogCollectorCallback
14
+ from llmflowstack.callbacks.stop_on_token import StopOnToken
15
+ from llmflowstack.schemas.params import GenerationParams, TrainParams
16
+ from llmflowstack.utils.exceptions import MissingEssentialProp
17
+ from llmflowstack.utils.generation_utils import create_generation_params
18
+
19
+
20
class LLaMA4Input(TypedDict):
    """Structured prompt payload accepted by the ``LLaMA4`` wrapper."""
    # User turn text; read unconditionally when the prompt string is rendered.
    input_text: str
    # Target completion appended after the model turn; a falsy value adds nothing.
    expected_answer: str | None
    # Optional system instructions; a falsy value is treated as "no system message".
    system_message: str | None
    # Declared for image inputs, but not read by any LLaMA4 method in this file.
    image_paths: list[str] | None
25
+
26
class LLaMA4(BaseModel):
    """Text-only wrapper around ``Llama4ForCausalLM``.

    Provides prompt construction, domain-adaptive pre-training (``dapt``),
    and blocking / streaming generation. Quantization is not supported by
    this wrapper: the constructor always forwards ``quantization=None``.
    """

    model: Llama4ForCausalLM | None = None
    question_fields = ["input_text", "system_message"]
    answer_fields = ["expected_answer"]

    # Token budget used when the caller does not set ``max_new_tokens``.
    _DEFAULT_MAX_NEW_TOKENS = 32768

    def __init__(
        self,
        checkpoint: str | None = None,
        seed: int | None = None,
        log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
    ) -> None:
        """Initialise the wrapper, forwarding everything to ``BaseModel``.

        Args:
            checkpoint: Optional checkpoint identifier passed to the base class.
            seed: Optional seed passed to the base class.
            log_level: Verbosity passed to the base class.
        """
        super().__init__(
            checkpoint=checkpoint,
            quantization=None,
            seed=seed,
            log_level=log_level
        )

    def _set_generation_stopping_tokens(
        self,
        tokens: list[int]
    ) -> None:
        """Store the stop-token ids: ``tokens`` plus the ids of ``<|eot|>``.

        Logs a warning and leaves the stop list untouched when no tokenizer
        is loaded.
        """
        if not self.tokenizer:
            self._log("Could not set stop tokens - generation may not work...", "WARNING")
            return None
        # add_special_tokens=False keeps the tokenizer from prepending its BOS
        # id, which would otherwise be registered as a stop token as well.
        eot_ids = self.tokenizer.encode("<|eot|>", add_special_tokens=False)
        self.stop_token_ids = tokens + eot_ids

    def _load_model(
        self,
        checkpoint: str,
        quantization: None = None
    ) -> None:
        """Load ``Llama4ForCausalLM`` weights from ``checkpoint``.

        ``quantization`` is accepted for signature compatibility and ignored.
        """
        self.model = Llama4ForCausalLM.from_pretrained(
            checkpoint,
            dtype="auto",
            device_map="auto",
            attn_implementation="eager"
        )

    def load_checkpoint(
        self,
        checkpoint: str,
        quantization: None = None
    ) -> None:
        """Delegate checkpoint loading to ``BaseModel.load_checkpoint``."""
        return super().load_checkpoint(checkpoint, quantization)

    def _build_input(
        self,
        data: LLaMA4Input
    ) -> str:
        """Render ``data`` into a Llama 4 chat prompt string.

        The system turn is emitted only when a non-empty system message is
        given. When ``expected_answer`` is set it is appended to the assistant
        turn and closed with ``<|eot|>`` (training samples); otherwise the
        prompt ends right after the assistant header so the model continues
        from there.

        Raises:
            MissingEssentialProp: If no tokenizer is loaded.
        """
        if not self.tokenizer:
            raise MissingEssentialProp("Could not find tokenizer.")

        system_message = data.get("system_message") or ""
        expected_answer = data.get("expected_answer")
        input_text = data["input_text"]

        system_turn = ""
        if system_message:
            system_turn = f"<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>"

        answer = f"{expected_answer}<|eot|>" if expected_answer else ""

        # NOTE(review): <|begin_of_text|> is not hard-coded here, mirroring the
        # LLaMA3 wrapper; this assumes the tokenizer adds BOS - confirm
        # against BaseModel._tokenize.
        return (
            f"{system_turn}"
            f"<|header_start|>user<|header_end|>\n\n{input_text}<|eot|>"
            f"<|header_start|>assistant<|header_end|>\n\n"
            f"{answer}"
        )

    def build_input(
        self,
        input_text: str,
        system_message: str | None = None,
        expected_answer: str | None = None,
        image_paths: list[str] | None = None
    ) -> LLaMA4Input:
        """Bundle the given fields into a ``LLaMA4Input`` dictionary.

        Raises:
            MissingEssentialProp: If no tokenizer is loaded (kept for parity
                with ``_build_input`` even though the tokenizer is not used).
        """
        if not self.tokenizer:
            raise MissingEssentialProp("Could not find tokenizer.")

        return {
            "input_text": input_text,
            "system_message": system_message,
            "expected_answer": expected_answer,
            "image_paths": image_paths
        }

    def _resolve_generation_config(
        self,
        params: GenerationParams | None
    ):
        """Apply the default token budget to ``params`` and build the config."""
        if params is None:
            params = GenerationParams(max_new_tokens=self._DEFAULT_MAX_NEW_TOKENS)
        elif params.max_new_tokens is None:
            params.max_new_tokens = self._DEFAULT_MAX_NEW_TOKENS
        return create_generation_params(params)

    def _render_prompt(
        self,
        input: LLaMA4Input | str
    ) -> str:
        """Turn a raw string or a ``LLaMA4Input`` into the final prompt text."""
        data = self.build_input(input_text=input) if isinstance(input, str) else input
        return self._build_input(data=data)

    def dapt(
        self,
        train_dataset: list,
        params: TrainParams | None = None,
        eval_dataset: list | None = None,
        save_at_end: bool = True,
        save_path: str | None = None
    ) -> None:
        """Run domain-adaptive pre-training with the Hugging Face ``Trainer``.

        Args:
            train_dataset: Samples handed to ``_promptfy_dataset_for_dapt``.
            params: Training hyper-parameters; ``TrainParams()`` when omitted.
            eval_dataset: Optional samples processed like ``train_dataset``
                and passed to the trainer as its evaluation set.
            save_at_end: Save a checkpoint after training (needs ``save_path``).
            save_path: Destination used when ``save_at_end`` is true.

        Returns early, after logging a warning, when the model or tokenizer
        is missing or the model is quantized.
        """
        if not self.model:
            self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
            return None
        if not self.tokenizer:
            self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
            return None

        self._log("Starting DAPT")

        if self.model_is_quantized:
            self._log("Cannot DAPT a quantized model.", "WARNING")
            return None

        if params is None:
            params = TrainParams()

        training_arguments = TrainingArguments(
            num_train_epochs=params.epochs,
            learning_rate=params.lr,
            gradient_accumulation_steps=params.gradient_accumulation,
            warmup_ratio=params.warmup_ratio,
            lr_scheduler_type="cosine_with_min_lr",
            lr_scheduler_kwargs={"min_lr_rate": 0.1},
            output_dir=None,
            save_strategy="no",
            logging_steps=params.logging_steps
        )

        # Only override the trainer's own default seed when one was configured.
        if self.seed is not None:
            training_arguments.seed = self.seed

        tokenized_train_dataset = self._tokenize_dataset_for_dapt(
            self._promptfy_dataset_for_dapt(train_dataset)
        )

        tokenized_eval_dataset = None
        if eval_dataset:
            tokenized_eval_dataset = self._tokenize_dataset_for_dapt(
                self._promptfy_dataset_for_dapt(eval_dataset)
            )

        trainer = Trainer(
            model=self.model,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            args=training_arguments,
            callbacks=[LogCollectorCallback()],
            data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        )

        trainer.train()

        if save_at_end and save_path:
            self.save_checkpoint(
                path=save_path
            )

        self._log("Finished DAPT")

    def fine_tune(
        self,
        train_dataset: list,
        params: TrainParams | None = None,
        eval_dataset: list | None = None,
        save_at_end: bool = True,
        save_path: str | None = None
    ) -> None:
        """Forward to ``dapt``; this class has no separate fine-tuning path."""
        self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
        return self.dapt(
            train_dataset=train_dataset,
            params=params,
            eval_dataset=eval_dataset,
            save_at_end=save_at_end,
            save_path=save_path
        )

    def generate(
        self,
        input: LLaMA4Input | str,
        params: GenerationParams | None = None
    ) -> str | None:
        """Generate a complete response for ``input``.

        Args:
            input: Raw user text or a prepared ``LLaMA4Input``.
            params: Generation settings; ``max_new_tokens`` falls back to
                ``_DEFAULT_MAX_NEW_TOKENS`` when unset.

        Returns:
            The decoded completion without the prompt, or ``None`` when the
            model or tokenizer is missing.
        """
        if self.model is None or self.tokenizer is None:
            self._log("Model or Tokenizer missing", "WARNING")
            return None

        self._log("Processing received input...'")

        self.model.generation_config = self._resolve_generation_config(params)

        input_ids, attention_mask = self._tokenize(self._render_prompt(input))

        self.model.eval()
        self.model.gradient_checkpointing_disable()

        start = time()

        with torch.no_grad():
            # eos_token_id=None is passed explicitly; stopping is handled by
            # the StopOnToken criterion built from self.stop_token_ids.
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                eos_token_id=None,
                stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
            )

        total_time = time() - start

        self._log(f"Response generated in {total_time:.4f} seconds")

        # Drop the prompt tokens so only the newly generated part is decoded.
        response = outputs[0][input_ids.shape[1]:]

        return self.tokenizer.decode(response, skip_special_tokens=True)

    def generate_stream(
        self,
        input: LLaMA4Input | str,
        params: GenerationParams | None = None
    ) -> Iterator[str]:
        """Generate a response for ``input`` and yield it piece by piece.

        Generation runs on a background thread that feeds a
        ``TextIteratorStreamer``; this generator yields whatever the streamer
        produces. Yields nothing (after logging a warning) when the model or
        tokenizer is missing.
        """
        if self.model is None or self.tokenizer is None:
            self._log("Model or Tokenizer missing", "WARNING")
            return

        self.model.generation_config = self._resolve_generation_config(params)

        input_ids, attention_mask = self._tokenize(self._render_prompt(input))

        # Same inference-mode setup as generate(); without it a model that was
        # just trained via dapt() would stream with training-mode layers.
        self.model.eval()
        self.model.gradient_checkpointing_disable()

        streamer = TextIteratorStreamer(
            cast(AutoTokenizer, self.tokenizer),
            skip_prompt=True,
            skip_special_tokens=True
        )

        generate_fn = partial(
            self.model.generate,
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            eos_token_id=None,
            streamer=streamer,
            stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
        )

        thread = threading.Thread(target=generate_fn)
        thread.start()

        yield from streamer

        # The streamer is exhausted once generation has ended; reap the worker.
        thread.join()
@@ -17,17 +17,31 @@ from llmflowstack.utils.exceptions import MissingEssentialProp
17
17
  from llmflowstack.utils.generation_utils import create_generation_params
18
18
 
19
19
 
20
- class GemmaInput(TypedDict):
20
+ class MedGemmaInput(TypedDict):
21
21
  input_text: str
22
22
  expected_answer: str | None
23
23
  system_message: str | None
24
24
 
25
- class Gemma(BaseModel):
25
+ class MedGemma(BaseModel):
26
26
  model: Gemma3ForCausalLM | None = None
27
27
  can_think = False
28
28
  question_fields = ["input_text", "system_message"]
29
29
  answer_fields = ["expected_answer"]
30
30
 
31
+ def __init__(
32
+ self,
33
+ checkpoint: str | None = None,
34
+ quantization: Literal["4bit"] | None = None,
35
+ seed: int | None = None,
36
+ log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
37
+ ) -> None:
38
+ return super().__init__(
39
+ checkpoint=checkpoint,
40
+ quantization=quantization,
41
+ seed=seed,
42
+ log_level=log_level
43
+ )
44
+
31
45
  def _set_generation_stopping_tokens(
32
46
  self,
33
47
  tokens: list[int]
@@ -41,17 +55,13 @@ class Gemma(BaseModel):
41
55
  def _load_model(
42
56
  self,
43
57
  checkpoint: str,
44
- quantization: Literal["8bit", "4bit"] | bool | None = None
58
+ quantization: Literal["4bit"] | None = None
45
59
  ) -> None:
46
60
  quantization_config = None
47
61
  if quantization == "4bit":
48
62
  quantization_config = BitsAndBytesConfig(
49
63
  load_in_4bit=True
50
64
  )
51
- if quantization == "8bit":
52
- quantization_config = BitsAndBytesConfig(
53
- load_in_8bit=True
54
- )
55
65
 
56
66
  self.model = Gemma3ForCausalLM.from_pretrained(
57
67
  checkpoint,
@@ -60,16 +70,22 @@ class Gemma(BaseModel):
60
70
  device_map="auto",
61
71
  attn_implementation="eager"
62
72
  )
73
+
74
+ def load_checkpoint(
75
+ self,
76
+ checkpoint: str,
77
+ quantization: Literal["4bit"] | None = None
78
+ ) -> None:
79
+ return super().load_checkpoint(checkpoint, quantization)
63
80
 
64
81
  def _build_input(
65
82
  self,
66
- input_text: str,
67
- expected_answer: str | None = None,
68
- system_message: str | None = None
83
+ data: MedGemmaInput
69
84
  ) -> str:
70
85
  if not self.tokenizer:
71
86
  raise MissingEssentialProp("Could not find tokenizer.")
72
87
 
88
+ system_message = data.get("system_message", "")
73
89
  if not system_message:
74
90
  system_message = ""
75
91
  if self.can_think:
@@ -78,11 +94,12 @@ class Gemma(BaseModel):
78
94
  if system_message:
79
95
  system_message = f"{system_message}\n"
80
96
 
97
+ expected_answer = data.get("expected_answer")
81
98
  answer = f"{expected_answer}<end_of_turn>" if expected_answer else ""
82
99
 
83
- return textwrap.dedent(
100
+ return (
84
101
  f"<start_of_turn>user"
85
- f"{system_message}\n{input_text}<end_of_turn>\n"
102
+ f"{system_message}\n{data["input_text"]}<end_of_turn>\n"
86
103
  f"<start_of_turn>model\n"
87
104
  f"{answer}"
88
105
  )
@@ -92,7 +109,7 @@ class Gemma(BaseModel):
92
109
  input_text: str,
93
110
  expected_answer: str | None = None,
94
111
  system_message: str | None = None
95
- ) -> GemmaInput:
112
+ ) -> MedGemmaInput:
96
113
  if not self.tokenizer:
97
114
  raise MissingEssentialProp("Could not find tokenizer.")
98
115
 
@@ -107,7 +124,7 @@ class Gemma(BaseModel):
107
124
 
108
125
  def generate(
109
126
  self,
110
- input: GemmaInput | str,
127
+ input: MedGemmaInput | str,
111
128
  params: GenerationParams | None = None,
112
129
  ) -> str | None:
113
130
  if self.model is None or self.tokenizer is None:
@@ -126,17 +143,18 @@ class Gemma(BaseModel):
126
143
 
127
144
  model_input = None
128
145
  if isinstance(input, str):
129
- model_input = self._build_input(
146
+ model_input = self.build_input(
130
147
  input_text=input
131
148
  )
149
+ model_input = self._build_input(
150
+ data=model_input
151
+ )
132
152
  else:
133
153
  model_input = self._build_input(
134
- input_text=input["input_text"],
135
- system_message=input["system_message"]
154
+ data=input
136
155
  )
137
156
 
138
157
  tokenized_input = self._tokenize(model_input)
139
-
140
158
  input_ids, attention_mask = tokenized_input
141
159
 
142
160
  self.model.eval()
@@ -174,7 +192,7 @@ class Gemma(BaseModel):
174
192
 
175
193
  def generate_stream(
176
194
  self,
177
- input: GemmaInput | str,
195
+ input: MedGemmaInput | str,
178
196
  params: GenerationParams | None = None
179
197
  ) -> Iterator[str]:
180
198
  if self.model is None or self.tokenizer is None:
@@ -191,14 +209,17 @@ class Gemma(BaseModel):
191
209
  generation_params = create_generation_params(params)
192
210
  self.model.generation_config = generation_params
193
211
 
212
+ model_input = None
194
213
  if isinstance(input, str):
195
- model_input = self._build_input(
214
+ model_input = self.build_input(
196
215
  input_text=input
197
216
  )
217
+ model_input = self._build_input(
218
+ data=model_input
219
+ )
198
220
  else:
199
221
  model_input = self._build_input(
200
- input_text=input["input_text"],
201
- system_message=input.get("system_message")
222
+ data=input
202
223
  )
203
224
 
204
225
  tokenized_input = self._tokenize(model_input)
@@ -0,0 +1,13 @@
1
+ from .Gemma import Gemma3
2
+ from .GPT_OSS import GPT_OSS
3
+ from .LLaMA3 import LLaMA3
4
+ from .LLaMA4 import LLaMA4
5
+ from .MedGemma import MedGemma
6
+
7
+ __all__ = [
8
+ "Gemma3",
9
+ "GPT_OSS",
10
+ "LLaMA3",
11
+ "LLaMA4",
12
+ "MedGemma"
13
+ ]
@@ -1,5 +1,7 @@
1
1
  import uuid
2
2
 
3
+ import chromadb
4
+ import chromadb.config
3
5
  from langchain_chroma import Chroma
4
6
  from langchain_core.documents import Document
5
7
  from langchain_core.embeddings import Embeddings
@@ -38,10 +40,15 @@ class RAGPipeline:
38
40
 
39
41
  self.encoder = SentenceTransformer(checkpoint, trust_remote_code=True)
40
42
 
43
+ client_settings = chromadb.config.Settings(
44
+ anonymized_telemetry=False
45
+ )
46
+
41
47
  self.vector_store = Chroma(
42
48
  collection_name=collection_name,
43
49
  embedding_function=EncoderWrapper(self.encoder),
44
- persist_directory=persist_directory
50
+ persist_directory=persist_directory,
51
+ client_settings=client_settings
45
52
  )
46
53
 
47
54
  self.splitter = RecursiveCharacterTextSplitter(
@@ -4,11 +4,11 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "llmflowstack"
7
- version = "1.0.2"
7
+ version = "1.1.1"
8
8
  authors = [
9
9
  { name = "Gustavo Henrique Ferreira Cruz", email = "gustavohferreiracruz@gmail.com" }
10
10
  ]
11
- description = "LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components."
11
+ description = "LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components."
12
12
  readme = "README.md"
13
13
  requires-python = ">=3.12"
14
14
  license = {text = "MIT"}
@@ -22,6 +22,7 @@ dependencies = [
22
22
  "datasets",
23
23
  "evaluate",
24
24
  "huggingface-hub",
25
+ "kernels",
25
26
  "langchain-chroma",
26
27
  "langchain_community",
27
28
  "nltk",
@@ -1,9 +0,0 @@
1
- from .Gemma import Gemma
2
- from .GPT_OSS import GPT_OSS
3
- from .LLaMA3 import LLaMA3
4
-
5
- __all__ = [
6
- "Gemma",
7
- "GPT_OSS",
8
- "LLaMA3"
9
- ]
File without changes
File without changes