llmflowstack 1.1.4__tar.gz → 1.2.1__tar.gz

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (29)
  1. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/PKG-INFO +1 -2
  2. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/__init__.py +8 -8
  3. llmflowstack-1.1.4/llmflowstack/base/base.py → llmflowstack-1.2.1/llmflowstack/decoders/BaseDecoder.py +27 -56
  4. {llmflowstack-1.1.4/llmflowstack/models → llmflowstack-1.2.1/llmflowstack/decoders}/GPT_OSS.py +35 -19
  5. {llmflowstack-1.1.4/llmflowstack/models → llmflowstack-1.2.1/llmflowstack/decoders}/Gemma.py +22 -14
  6. {llmflowstack-1.1.4/llmflowstack/models → llmflowstack-1.2.1/llmflowstack/decoders}/LLaMA3.py +30 -24
  7. {llmflowstack-1.1.4/llmflowstack/models → llmflowstack-1.2.1/llmflowstack/decoders}/LLaMA4.py +23 -15
  8. {llmflowstack-1.1.4/llmflowstack/models → llmflowstack-1.2.1/llmflowstack/decoders}/MedGemma.py +18 -11
  9. llmflowstack-1.2.1/llmflowstack/rag/VectorDatabase.py +278 -0
  10. llmflowstack-1.2.1/llmflowstack/rag/__init__.py +5 -0
  11. llmflowstack-1.2.1/llmflowstack/utils/logging.py +8 -0
  12. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/pyproject.toml +1 -2
  13. llmflowstack-1.1.4/llmflowstack/callbacks/__init__.py +0 -0
  14. llmflowstack-1.1.4/llmflowstack/rag/__iinit__.py +0 -5
  15. llmflowstack-1.1.4/llmflowstack/rag/pipeline.py +0 -279
  16. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/.github/workflows/python-publish.yml +0 -0
  17. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/.gitignore +0 -0
  18. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/LICENSE +0 -0
  19. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/README.md +0 -0
  20. {llmflowstack-1.1.4/llmflowstack/base → llmflowstack-1.2.1/llmflowstack/callbacks}/__init__.py +0 -0
  21. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/callbacks/log_collector.py +0 -0
  22. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/callbacks/stop_on_token.py +0 -0
  23. {llmflowstack-1.1.4/llmflowstack/models → llmflowstack-1.2.1/llmflowstack/decoders}/__init__.py +0 -0
  24. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/schemas/__init__.py +0 -0
  25. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/schemas/params.py +0 -0
  26. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/utils/__init__.py +0 -0
  27. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/utils/evaluation_methods.py +0 -0
  28. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/utils/exceptions.py +0 -0
  29. {llmflowstack-1.1.4 → llmflowstack-1.2.1}/llmflowstack/utils/generation_utils.py +0 -0
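Taken together, the renames above are a package-level reorganisation: everything under llmflowstack/models/ moves to llmflowstack/decoders/, base/base.py becomes decoders/BaseDecoder.py, and rag/pipeline.py is replaced by a new rag/VectorDatabase.py. A minimal migration sketch for code that imported the submodules directly, based only on the paths listed above (whether 1.2.1 keeps any compatibility shims for the old paths is not shown in this diff):

# llmflowstack 1.1.4 — old import paths
# from llmflowstack.models.LLaMA3 import LLaMA3
# from llmflowstack.rag.pipeline import RAGPipeline

# llmflowstack 1.2.1 — new locations, per the file renames above
from llmflowstack.decoders.LLaMA3 import LLaMA3
from llmflowstack.rag import VectorDatabase  # VectorDatabase replaces RAGPipeline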
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: llmflowstack
3
- Version: 1.1.4
3
+ Version: 1.2.1
4
4
  Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
5
5
  Author-email: Gustavo Henrique Ferreira Cruz <gustavohferreiracruz@gmail.com>
6
6
  License: MIT
@@ -10,7 +10,6 @@ Requires-Dist: accelerate
10
10
  Requires-Dist: bert-score
11
11
  Requires-Dist: bitsandbytes
12
12
  Requires-Dist: chromadb
13
- Requires-Dist: colorama
14
13
  Requires-Dist: datasets
15
14
  Requires-Dist: evaluate
16
15
  Requires-Dist: huggingface-hub
@@ -1,9 +1,9 @@
1
- from .models.Gemma import Gemma3
2
- from .models.GPT_OSS import GPT_OSS
3
- from .models.LLaMA3 import LLaMA3
4
- from .models.LLaMA4 import LLaMA4
5
- from .models.MedGemma import MedGemma
6
- from .rag.pipeline import RAGPipeline
1
+ from .decoders.Gemma import Gemma3
2
+ from .decoders.GPT_OSS import GPT_OSS
3
+ from .decoders.LLaMA3 import LLaMA3
4
+ from .decoders.LLaMA4 import LLaMA4
5
+ from .decoders.MedGemma import MedGemma
6
+ from .rag import VectorDatabase
7
7
  from .schemas.params import (GenerationBeamsParams, GenerationParams,
8
8
  GenerationSampleParams, TrainParams)
9
9
  from .utils.evaluation_methods import text_evaluation
@@ -14,10 +14,10 @@ __all__ = [
14
14
  "LLaMA3",
15
15
  "LLaMA4",
16
16
  "MedGemma",
17
- "RAGPipeline",
18
17
  "GenerationBeamsParams",
19
18
  "GenerationParams",
20
19
  "GenerationSampleParams",
21
20
  "TrainParams",
22
- "text_evaluation"
21
+ "text_evaluation",
22
+ "VectorDatabase"
23
23
  ]
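Because the package __init__ re-exports the classes from their new locations, top-level imports are unchanged; the only visible API difference is that RAGPipeline disappears from __all__ and VectorDatabase takes its place. A sketch of what downstream code sees:

# Works identically in 1.1.4 and 1.2.1 despite the models/ -> decoders/ move.
from llmflowstack import GPT_OSS, Gemma3, LLaMA3, LLaMA4, MedGemma

# New in 1.2.1 (replaces the removed RAGPipeline export).
from llmflowstack import VectorDatabase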
@@ -1,15 +1,14 @@
1
1
  import gc
2
2
  import json
3
- import logging
4
3
  import os
5
4
  import random
6
5
  from abc import ABC, abstractmethod
6
+ from logging import getLogger
7
7
  from typing import Any, Literal, cast
8
8
  from uuid import uuid4
9
9
 
10
10
  import numpy as np
11
11
  import torch
12
- from colorama import Fore, Style, init
13
12
  from datasets import Dataset
14
13
  from torch import Tensor
15
14
  from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -20,15 +19,15 @@ from trl.trainer.sft_trainer import SFTTrainer
20
19
  from llmflowstack.callbacks.log_collector import LogCollectorCallback
21
20
  from llmflowstack.schemas.params import GenerationParams, TrainParams
22
21
  from llmflowstack.utils.exceptions import MissingEssentialProp
22
+ from llmflowstack.utils.logging import LogLevel
23
23
 
24
24
 
25
- class BaseModel(ABC):
25
+ class BaseDecoder(ABC):
26
26
  model = None
27
27
  tokenizer = None
28
28
  _model_id = None
29
29
  model_is_quantized = None
30
30
  seed = None
31
- log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO"
32
31
  stop_token_ids = []
33
32
  question_fields = []
34
33
  answer_fields = []
@@ -37,20 +36,17 @@ class BaseModel(ABC):
37
36
  self,
38
37
  checkpoint: str | None = None,
39
38
  quantization: Literal["4bit", "8bit"] | bool | None = None,
40
- seed: int | None = None,
41
- log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
39
+ seed: int | None = None
42
40
  ) -> None:
43
41
  if not self.question_fields or not self.answer_fields:
44
42
  raise NotImplementedError("Subclasses must define question_fields and answer_fields.")
45
43
 
46
- init(autoreset=True)
47
44
  if seed:
48
45
  self._set_seed(seed)
49
46
 
50
47
  self._base_model = checkpoint
51
48
 
52
- self._set_logger(log_level)
53
- self.log_level = log_level
49
+ self.logger = getLogger(f"LLMFlowStack.{self.__class__.__name__}")
54
50
 
55
51
  self.tokenizer: PreTrainedTokenizerBase | None = None
56
52
 
@@ -61,6 +57,17 @@ class BaseModel(ABC):
61
57
  quantization=quantization
62
58
  )
63
59
 
60
+ def _log(
61
+ self,
62
+ message: str,
63
+ level: LogLevel = LogLevel.INFO,
64
+ ) -> None:
65
+ log_func = getattr(self.logger, level.lower(), None)
66
+ if log_func:
67
+ log_func(message)
68
+ else:
69
+ self.logger.info(message)
70
+
64
71
  @abstractmethod
65
72
  def _load_model(
66
73
  self,
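The new _log helper dispatches on a LogLevel value imported from llmflowstack/utils/logging.py, an 8-line file added in 1.2.1 whose body does not appear in this diff. Because _log calls level.lower() and falls back to logger.info, one plausible shape is a string-valued enum; the sketch below is a guess at that file, not its actual contents:

# Hypothetical reconstruction of llmflowstack/utils/logging.py (the real file is not shown here).
from enum import Enum


class LogLevel(str, Enum):
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"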
@@ -84,7 +91,7 @@ class BaseModel(ABC):
84
91
  quantization: Any
85
92
  ) -> None:
86
93
  if self.model:
87
- self._log("A model is already loaded. Attempting to reset it.", "WARNING")
94
+ self._log("A model is already loaded. Attempting to reset it.", LogLevel.WARNING)
88
95
  self.unload_model()
89
96
 
90
97
  self._log(f"Loading model on '{checkpoint}'")
@@ -132,42 +139,6 @@ class BaseModel(ABC):
132
139
  ) -> None:
133
140
  self._model_id = uuid4()
134
141
 
135
- def _set_logger(
136
- self,
137
- level: str
138
- ) -> None:
139
- level_map = {
140
- "DEBUG": logging.DEBUG,
141
- "INFO": logging.INFO,
142
- "WARNING": logging.WARNING,
143
- "ERROR": logging.ERROR,
144
- }
145
- numeric_level = level_map.get(level.upper(), logging.INFO)
146
-
147
- logging.basicConfig(
148
- level=numeric_level,
149
- format="%(asctime)s - %(levelname)s - %(message)s"
150
- )
151
- self.logger = logging.getLogger(__name__)
152
-
153
- def _log(
154
- self,
155
- info: str,
156
- level: Literal["INFO", "WARNING", "ERROR", "DEBUG"] = "INFO"
157
- ) -> None:
158
- if level == "INFO":
159
- colored_msg = f"{Fore.GREEN}{info}{Style.RESET_ALL}"
160
- self.logger.info(colored_msg)
161
- elif level == "WARNING":
162
- colored_msg = f"{Fore.YELLOW}{info}{Style.RESET_ALL}"
163
- self.logger.warning(colored_msg)
164
- elif level == "ERROR":
165
- colored_msg = f"{Fore.RED}{info}{Style.RESET_ALL}"
166
- self.logger.error(colored_msg)
167
- elif level == "DEBUG":
168
- colored_msg = f"{Fore.BLUE}{info}{Style.RESET_ALL}"
169
- self.logger.debug(colored_msg)
170
-
171
142
  def _set_seed(
172
143
  self,
173
144
  seed: int
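Both colour handling and logger setup are removed here: BaseDecoder no longer accepts a log_level argument, no longer calls colorama, and no longer invokes logging.basicConfig on the caller's behalf. Messages now go to loggers named "LLMFlowStack.<ClassName>" and stay silent until the host application opts in. A sketch of what a 1.2.1 caller might do instead (the checkpoint string is a placeholder, not taken from this diff):

import logging

from llmflowstack import LLaMA3

# 1.1.4 configured logging for you via log_level="INFO"; 1.2.1 leaves it to the application.
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)

# Optionally raise verbosity for this library only, without touching the root logger.
logging.getLogger("LLMFlowStack").setLevel(logging.DEBUG)

model = LLaMA3(
    checkpoint="meta-llama/Llama-3.1-8B",  # placeholder checkpoint
    quantization="4bit",                   # Literal["4bit", "8bit"] per the LLaMA3 signature below
    seed=42,                               # the former log_level keyword is gone
)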
@@ -190,10 +161,10 @@ class BaseModel(ABC):
190
161
  path: str
191
162
  ) -> None:
192
163
  if not self.model:
193
- self._log("No model to save.", "WARNING")
164
+ self._log("No model to save.", LogLevel.WARNING)
194
165
  return None
195
166
  if not self.tokenizer:
196
- self._log("No tokenizer to save.", "WARNING")
167
+ self._log("No tokenizer to save.", LogLevel.WARNING)
197
168
  return None
198
169
 
199
170
  os.makedirs(path, exist_ok=True)
@@ -299,16 +270,16 @@ class BaseModel(ABC):
299
270
  save_path: str | None = None
300
271
  ) -> None:
301
272
  if not self.model:
302
- self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
273
+ self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
303
274
  return None
304
275
  if not self.tokenizer:
305
- self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
276
+ self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
306
277
  return None
307
278
 
308
279
  self._log("Starting DAPT")
309
280
 
310
281
  if self.model_is_quantized:
311
- self._log("Cannot DAPT a quantized model.", "WARNING")
282
+ self._log("Cannot DAPT a quantized model.", LogLevel.WARNING)
312
283
  return None
313
284
 
314
285
  if params is None:
@@ -443,16 +414,16 @@ class BaseModel(ABC):
443
414
  save_path: str | None = None
444
415
  ) -> None:
445
416
  if not self.model:
446
- self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
417
+ self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
447
418
  return None
448
419
  if not self.tokenizer:
449
- self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
420
+ self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
450
421
  return None
451
422
 
452
423
  self._log("Starting fine-tune")
453
424
 
454
425
  if self.model_is_quantized:
455
- self._log("Cannot fine-tune a quantized model.", "WARNING")
426
+ self._log("Cannot fine-tune a quantized model.", LogLevel.WARNING)
456
427
  return None
457
428
 
458
429
  if params is None:
@@ -521,8 +492,8 @@ class BaseModel(ABC):
521
492
  self._model_id = None
522
493
  self._log("Reset successfully.")
523
494
  except Exception as e:
524
- self._log("Couldn't reset model...", "ERROR")
525
- self._log(f"{str(e)}", "DEBUG")
495
+ self._log("Couldn't reset model...", LogLevel.ERROR)
496
+ self._log(f"{str(e)}", LogLevel.DEBUG)
526
497
 
527
498
  def set_seed(self, seed: int) -> None:
528
499
  self._log(f"Setting seed value {seed}")
@@ -1,4 +1,3 @@
1
- import textwrap
2
1
  import threading
3
2
  from functools import partial
4
3
  from time import time
@@ -11,11 +10,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
11
10
  from transformers.models.gpt_oss import GptOssForCausalLM
12
11
  from transformers.utils.quantization_config import Mxfp4Config
13
12
 
14
- from llmflowstack.base.base import BaseModel
15
13
  from llmflowstack.callbacks.stop_on_token import StopOnToken
14
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
16
15
  from llmflowstack.schemas.params import GenerationParams
17
16
  from llmflowstack.utils.exceptions import MissingEssentialProp
18
17
  from llmflowstack.utils.generation_utils import create_generation_params
18
+ from llmflowstack.utils.logging import LogLevel
19
19
 
20
20
 
21
21
  class GPTOSSInput(TypedDict):
@@ -24,11 +24,11 @@ class GPTOSSInput(TypedDict):
24
24
  developer_message: str | None
25
25
  expected_answer: str | None
26
26
  reasoning_message: str | None
27
- reasoning_level: Literal["Low", "Medium", "High"] | None
27
+ reasoning_level: Literal["Low", "Medium", "High", "Off"] | None
28
28
 
29
- class GPT_OSS(BaseModel):
29
+ class GPT_OSS(BaseDecoder):
30
30
  model: GptOssForCausalLM | None = None
31
- reasoning_level: Literal["Low", "Medium", "High"] = "Low"
31
+ reasoning_level: Literal["Low", "Medium", "High", "Off"] = "Low"
32
32
  question_fields = ["input_text", "developer_message", "system_message"]
33
33
  answer_fields = ["expected_answer", "reasoning_message"]
34
34
 
@@ -36,14 +36,12 @@ class GPT_OSS(BaseModel):
36
36
  self,
37
37
  checkpoint: str | None = None,
38
38
  quantization: bool | None = None,
39
- seed: int | None = None,
40
- log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
39
+ seed: int | None = None
41
40
  ) -> None:
42
41
  return super().__init__(
43
42
  checkpoint=checkpoint,
44
43
  quantization=quantization,
45
- seed=seed,
46
- log_level=log_level
44
+ seed=seed
47
45
  )
48
46
 
49
47
  def _set_generation_stopping_tokens(
@@ -51,7 +49,7 @@ class GPT_OSS(BaseModel):
51
49
  tokens: list[int]
52
50
  ) -> None:
53
51
  if not self.tokenizer:
54
- self._log("Could not set stop tokens - generation may not work...", "WARNING")
52
+ self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
55
53
  return None
56
54
  encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
57
55
  particular_tokens = encoding.stop_tokens_for_assistant_actions()
@@ -76,7 +74,7 @@ class GPT_OSS(BaseModel):
76
74
  attn_implementation="eager",
77
75
  )
78
76
  except Exception as _:
79
- self._log("Error trying to load the model. Defaulting to load without quantization...", "WARNING")
77
+ self._log("Error trying to load the model. Defaulting to load without quantization...", LogLevel.WARNING)
80
78
  self.model = GptOssForCausalLM.from_pretrained(
81
79
  checkpoint,
82
80
  dtype="auto",
@@ -104,6 +102,8 @@ class GPT_OSS(BaseModel):
104
102
 
105
103
  system_message = data.get("system_message", "")
106
104
  system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\nReasoning: {reasoning}\n\n{system_message}# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
105
+ if reasoning == "Off":
106
+ system_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\n\n{system_message}# Valid channels: final. Channel must be included for every message.<|end|>"
107
107
 
108
108
  developer_text = ""
109
109
  developer_message = data.get("developer_message", "")
@@ -119,7 +119,14 @@ class GPT_OSS(BaseModel):
119
119
  if expected_answer:
120
120
  assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"
121
121
 
122
- return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{data["input_text"]}<|end|>{assistant_text}""")
122
+ if not expected_answer and reasoning == "Off":
123
+ assistant_text += "<|start|>assistant<|channel|>final<|message|>"
124
+
125
+ return (
126
+ f"{system_text}{developer_text}"
127
+ f"<|start|>user<|message|>{data["input_text"]}<|end|>"
128
+ f"{assistant_text}"
129
+ )
123
130
 
124
131
  def build_input(
125
132
  self,
@@ -128,7 +135,7 @@ class GPT_OSS(BaseModel):
128
135
  developer_message: str | None = None,
129
136
  expected_answer: str | None = None,
130
137
  reasoning_message: str | None = None,
131
- reasoning_level: Literal["Low", "Medium", "High"] | None = None
138
+ reasoning_level: Literal["Low", "Medium", "High", "Off"] | None = None
132
139
  ) -> GPTOSSInput:
133
140
  if not self.tokenizer:
134
141
  raise MissingEssentialProp("Could not find tokenizer.")
@@ -144,7 +151,7 @@ class GPT_OSS(BaseModel):
144
151
 
145
152
  def set_reasoning_level(
146
153
  self,
147
- level: Literal["Low", "Medium", "High"]
154
+ level: Literal["Low", "Medium", "High", "Off"]
148
155
  ) -> None:
149
156
  self.reasoning_level = level
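With the new "Off" level, GPT_OSS builds a system prompt that omits the Reasoning line, restricts valid channels to final, and, when no expected answer is supplied, pre-opens the assistant's final channel so generation skips the analysis phase. A hedged usage sketch, assuming `model` is a loaded GPT_OSS instance and using placeholder text:

# Set the level once on the instance...
model.set_reasoning_level("Off")

# ...or pass it per request when building the input dict.
request = model.build_input(
    input_text="Summarize the findings in two sentences.",  # placeholder, not from the diff
    system_message="You are a concise assistant.",
    reasoning_level="Off",
)
# The request dict is then passed to the model's generation method (not named in these hunks);
# with reasoning "Off", only final-channel text is produced.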
150
157
 
@@ -154,7 +161,7 @@ class GPT_OSS(BaseModel):
154
161
  params: GenerationParams | None = None
155
162
  ) -> str | None:
156
163
  if self.model is None or self.tokenizer is None:
157
- self._log("Model or Tokenizer missing", "WARNING")
164
+ self._log("Model or Tokenizer missing", LogLevel.WARNING)
158
165
  return None
159
166
 
160
167
  self._log(f"Processing received input...'")
@@ -222,11 +229,13 @@ class GPT_OSS(BaseModel):
222
229
  params: GenerationParams | None = None
223
230
  ) -> Iterator[str]:
224
231
  if self.model is None or self.tokenizer is None:
225
- self._log("Model or Tokenizer missing", "WARNING")
232
+ self._log("Model or Tokenizer missing", LogLevel.WARNING)
226
233
  if False:
227
234
  yield ""
228
235
  return
229
236
 
237
+ self._log(f"Processing received input...'")
238
+
230
239
  if params is None:
231
240
  params = GenerationParams(max_new_tokens=32768)
232
241
  elif params.max_new_tokens is None:
@@ -266,19 +275,26 @@ class GPT_OSS(BaseModel):
266
275
  stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
267
276
  )
268
277
 
278
+ start = time()
279
+
269
280
  thread = threading.Thread(target=generate_fn)
270
281
  thread.start()
271
282
 
272
- done_thinking = False
283
+ done_thinking = self.reasoning_level == "Off"
273
284
  buffer = ""
274
285
 
275
286
  for new_text in streamer:
276
287
  buffer += new_text
277
288
 
278
- if "final" in buffer:
289
+ if "final" in buffer and not done_thinking:
279
290
  done_thinking = True
280
291
  buffer = buffer.split("final", 1)[1]
281
292
 
282
293
  if done_thinking:
283
294
  yield buffer
284
- buffer = ""
295
+ buffer = ""
296
+
297
+ end = time()
298
+ total_time = end - start
299
+
300
+ self._log(f"Response generated in {total_time:.4f} seconds")
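On the streaming side, the generator now starts in the "done thinking" state whenever reasoning is "Off" (tokens are yielded immediately instead of waiting for the final-channel marker), and the total generation time is logged once the stream is exhausted. A consumption sketch; generate_stream is an assumed name, since the Iterator[str] method's identifier is not visible in this hunk:

# 'generate_stream' is a guess at the streaming method shown above.
for chunk in model.generate_stream(request):
    print(chunk, end="", flush=True)   # only final-channel text reaches the caller
print()
# When the loop finishes, the decoder logs "Response generated in X.XXXX seconds".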
@@ -10,12 +10,13 @@ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
10
10
  from transformers.models.gemma3 import Gemma3ForCausalLM
11
11
  from transformers.utils.quantization_config import BitsAndBytesConfig
12
12
 
13
- from llmflowstack.base.base import BaseModel
14
13
  from llmflowstack.callbacks.log_collector import LogCollectorCallback
15
14
  from llmflowstack.callbacks.stop_on_token import StopOnToken
15
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
16
16
  from llmflowstack.schemas.params import GenerationParams, TrainParams
17
17
  from llmflowstack.utils.exceptions import MissingEssentialProp
18
18
  from llmflowstack.utils.generation_utils import create_generation_params
19
+ from llmflowstack.utils.logging import LogLevel
19
20
 
20
21
 
21
22
  class Gemma3Input(TypedDict):
@@ -24,7 +25,7 @@ class Gemma3Input(TypedDict):
24
25
  system_message: str | None
25
26
  image_paths: list[str] | None
26
27
 
27
- class Gemma3(BaseModel):
28
+ class Gemma3(BaseDecoder):
28
29
  model: Gemma3ForCausalLM | None = None
29
30
  question_fields = ["input_text", "system_message"]
30
31
  answer_fields = ["expected_answer"]
@@ -33,14 +34,12 @@ class Gemma3(BaseModel):
33
34
  self,
34
35
  checkpoint: str | None = None,
35
36
  quantization: Literal["4bit"] | None = None,
36
- seed: int | None = None,
37
- log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
37
+ seed: int | None = None
38
38
  ) -> None:
39
39
  return super().__init__(
40
40
  checkpoint=checkpoint,
41
41
  quantization=quantization,
42
- seed=seed,
43
- log_level=log_level
42
+ seed=seed
44
43
  )
45
44
 
46
45
  def _set_generation_stopping_tokens(
@@ -48,7 +47,7 @@ class Gemma3(BaseModel):
48
47
  tokens: list[int]
49
48
  ) -> None:
50
49
  if not self.tokenizer:
51
- self._log("Could not set stop tokens - generation may not work...", "WARNING")
50
+ self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
52
51
  return None
53
52
  particular_tokens = self.tokenizer.encode("<end_of_turn>")
54
53
  self.stop_token_ids = tokens + particular_tokens
@@ -129,16 +128,16 @@ class Gemma3(BaseModel):
129
128
  save_path: str | None = None
130
129
  ) -> None:
131
130
  if not self.model:
132
- self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
131
+ self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
133
132
  return None
134
133
  if not self.tokenizer:
135
- self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
134
+ self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
136
135
  return None
137
136
 
138
137
  self._log("Starting Training")
139
138
 
140
139
  if self.model_is_quantized:
141
- self._log("Cannot traub a quantized model.", "WARNING")
140
+ self._log("Cannot traub a quantized model.", LogLevel.WARNING)
142
141
  return None
143
142
 
144
143
  if params is None:
@@ -195,7 +194,7 @@ class Gemma3(BaseModel):
195
194
  save_at_end = True,
196
195
  save_path: str | None = None
197
196
  ) -> None:
198
- self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
197
+ self._log("Only 'dapt' method is available for this class. Redirecting call to it.", LogLevel.WARNING)
199
198
  return self.dapt(
200
199
  train_dataset=train_dataset,
201
200
  params=params,
@@ -210,7 +209,7 @@ class Gemma3(BaseModel):
210
209
  params: GenerationParams | None = None,
211
210
  ) -> str | None:
212
211
  if self.model is None or self.tokenizer is None:
213
- self._log("Model or Tokenizer missing", "WARNING")
212
+ self._log("Model or Tokenizer missing", LogLevel.WARNING)
214
213
  return None
215
214
 
216
215
  self._log(f"Processing received input...'")
@@ -267,10 +266,12 @@ class Gemma3(BaseModel):
267
266
  params: GenerationParams | None = None
268
267
  ) -> Iterator[str]:
269
268
  if self.model is None or self.tokenizer is None:
270
- self._log("Model or Tokenizer missing", "WARNING")
269
+ self._log("Model or Tokenizer missing", LogLevel.WARNING)
271
270
  if False:
272
271
  yield ""
273
272
  return
273
+
274
+ self._log(f"Processing received input...'")
274
275
 
275
276
  if params is None:
276
277
  params = GenerationParams(max_new_tokens=32768)
@@ -312,8 +313,15 @@ class Gemma3(BaseModel):
312
313
  stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
313
314
  )
314
315
 
316
+ start = time()
317
+
315
318
  thread = threading.Thread(target=generate_fn)
316
319
  thread.start()
317
320
 
318
321
  for new_text in streamer:
319
- yield new_text
322
+ yield new_text
323
+
324
+ end = time()
325
+ total_time = end - start
326
+
327
+ self._log(f"Response generated in {total_time:.4f} seconds")
@@ -1,5 +1,5 @@
1
- import textwrap
2
1
  import threading
2
+ from functools import partial
3
3
  from time import time
4
4
  from typing import Iterator, Literal, TypedDict, cast
5
5
 
@@ -9,11 +9,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
9
9
  from transformers.models.llama import LlamaForCausalLM
10
10
  from transformers.utils.quantization_config import BitsAndBytesConfig
11
11
 
12
- from llmflowstack.base.base import BaseModel
13
12
  from llmflowstack.callbacks.stop_on_token import StopOnToken
13
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
14
14
  from llmflowstack.schemas.params import GenerationParams
15
15
  from llmflowstack.utils.exceptions import MissingEssentialProp
16
16
  from llmflowstack.utils.generation_utils import create_generation_params
17
+ from llmflowstack.utils.logging import LogLevel
17
18
 
18
19
 
19
20
  class LLaMA3Input(TypedDict):
@@ -21,7 +22,7 @@ class LLaMA3Input(TypedDict):
21
22
  expected_answer: str | None
22
23
  system_message: str | None
23
24
 
24
- class LLaMA3(BaseModel):
25
+ class LLaMA3(BaseDecoder):
25
26
  model: LlamaForCausalLM | None = None
26
27
  question_fields = ["input_text", "system_message"]
27
28
  answer_fields = ["expected_answer"]
@@ -30,14 +31,12 @@ class LLaMA3(BaseModel):
30
31
  self,
31
32
  checkpoint: str | None = None,
32
33
  quantization: Literal["4bit", "8bit"] | None = None,
33
- seed: int | None = None,
34
- log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
34
+ seed: int | None = None
35
35
  ) -> None:
36
36
  return super().__init__(
37
37
  checkpoint=checkpoint,
38
38
  quantization=quantization,
39
- seed=seed,
40
- log_level=log_level
39
+ seed=seed
41
40
  )
42
41
 
43
42
  def _set_generation_stopping_tokens(
@@ -45,7 +44,7 @@ class LLaMA3(BaseModel):
45
44
  tokens: list[int]
46
45
  ) -> None:
47
46
  if not self.tokenizer:
48
- self._log("Could not set stop tokens - generation may not work...", "WARNING")
47
+ self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
49
48
  return None
50
49
  particular_tokens = self.tokenizer.encode("<|eot_id|>")
51
50
  self.stop_token_ids = tokens + particular_tokens
@@ -92,7 +91,7 @@ class LLaMA3(BaseModel):
92
91
 
93
92
  system_message = data.get("system_message", "")
94
93
 
95
- return textwrap.dedent(
94
+ return (
96
95
  f"<|start_header_id|>system<|end_header_id|>{system_message}\n"
97
96
  f"<|eot_id|><|start_header_id|>user<|end_header_id|>{data["input_text"]}\n"
98
97
  f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>{answer}"
@@ -119,7 +118,7 @@ class LLaMA3(BaseModel):
119
118
  params: GenerationParams | None = None
120
119
  ) -> str | None:
121
120
  if self.model is None or self.tokenizer is None:
122
- self._log("Model or Tokenizer missing", "WARNING")
121
+ self._log("Model or Tokenizer missing", LogLevel.WARNING)
123
122
  return None
124
123
 
125
124
  self.model
@@ -184,11 +183,13 @@ class LLaMA3(BaseModel):
184
183
  params: GenerationParams | None = None
185
184
  ) -> Iterator[str]:
186
185
  if self.model is None or self.tokenizer is None:
187
- self._log("Model or Tokenizer missing", "WARNING")
186
+ self._log("Model or Tokenizer missing", LogLevel.WARNING)
188
187
  if False:
189
188
  yield ""
190
189
  return
191
190
 
191
+ self._log(f"Processing received input...'")
192
+
192
193
  if params is None:
193
194
  params = GenerationParams(max_new_tokens=8192)
194
195
  elif params.max_new_tokens is None:
@@ -219,20 +220,25 @@ class LLaMA3(BaseModel):
219
220
  skip_special_tokens=True
220
221
  )
221
222
 
222
- def _generate() -> None:
223
- assert self.model is not None
224
- with torch.no_grad():
225
- self.model.generate(
226
- input_ids=input_ids,
227
- attention_mask=attention_mask,
228
- use_cache=True,
229
- eos_token_id=None,
230
- streamer=streamer,
231
- stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
232
- )
223
+ generate_fn = partial(
224
+ self.model.generate,
225
+ input_ids=input_ids,
226
+ attention_mask=attention_mask,
227
+ use_cache=True,
228
+ eos_token_id=None,
229
+ streamer=streamer,
230
+ stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
231
+ )
233
232
 
234
- thread = threading.Thread(target=_generate)
233
+ start = time()
234
+
235
+ thread = threading.Thread(target=generate_fn)
235
236
  thread.start()
236
237
 
237
238
  for new_text in streamer:
238
- yield new_text
239
+ yield new_text
240
+
241
+ end = time()
242
+ total_time = end - start
243
+
244
+ self._log(f"Response generated in {total_time:.4f} seconds")