llmflowstack-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
+ from .models.Gemma import Gemma
+ from .models.GPT_OSS import GPT_OSS
+ from .models.LLaMA3 import LLaMA3
+ from .rag.pipeline import RAGPipeline
+ from .schemas.params import (GenerationBeamsParams, GenerationParams,
+                              GenerationSampleParams, TrainParams)
+ from .utils.evaluation_methods import text_evaluation
+
+ __all__ = [
+     "Gemma",
+     "LLaMA3",
+     "GPT_OSS",
+     "RAGPipeline",
+     "GenerationBeamsParams",
+     "GenerationParams",
+     "GenerationSampleParams",
+     "TrainParams",
+     "text_evaluation"
+ ]
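Taken together with the `BaseModel` constructor shown in the next file, these exports suggest an entry point along the following lines. This is a minimal sketch, not code from the package: the checkpoint name is a placeholder, and each model class defines its own `generate` input format.

from llmflowstack import LLaMA3

model = LLaMA3(
    checkpoint="path-or-hub-id-of-a-llama-3-checkpoint",  # placeholder
    quantization="4bit",
    seed=42,
)
output = model.generate("Hello!")  # input format depends on the subclass
print(output)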
@@ -0,0 +1,527 @@
+ import json
+ import logging
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from typing import Any, Literal, cast
+ from uuid import uuid4
+
+ import numpy as np
+ import torch
+ from colorama import Fore, Style, init
+ from datasets import Dataset
+ from torch import Tensor
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
+ from transformers.tokenization_utils_base import BatchEncoding
+ from trl.trainer.sft_config import SFTConfig
+ from trl.trainer.sft_trainer import SFTTrainer
+
+ from llmflowstack.callbacks.log_collector import LogCollectorCallback
+ from llmflowstack.schemas.params import GenerationParams, TrainParams
+ from llmflowstack.utils.exceptions import MissingEssentialProp
+
+
+ class BaseModel(ABC):
+     model = None
+     tokenizer = None
+     _model_id = None
+     model_is_quantized = None
+     seed = None
+     log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO"
+     stop_token_ids = []
+     question_fields = []
+     answer_fields = []
+
+     def __init__(
+         self,
+         checkpoint: str | None = None,
+         quantization: Literal["8bit", "4bit"] | bool | None = None,
+         seed: int | None = None,
+         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+     ) -> None:
+         if not self.question_fields or not self.answer_fields:
+             raise NotImplementedError("Subclasses must define question_fields and answer_fields.")
+
+         init(autoreset=True)
+         if seed is not None:
+             self._set_seed(seed)
+
+         self._base_model = checkpoint
+
+         self._set_logger(log_level)
+         self.log_level = log_level
+
+         self.tokenizer: PreTrainedTokenizerBase | None = None
+
+         if checkpoint:
+             self._checkpoint = checkpoint
+             self.load_checkpoint(
+                 checkpoint=checkpoint,
+                 quantization=quantization
+             )
+
+     @abstractmethod
+     def _load_model(
+         self,
+         checkpoint: str,
+         quantization: Literal["8bit", "4bit"] | bool | None = None
+     ) -> None:
+         pass
+
+     def _load_tokenizer(self, checkpoint: str) -> None:
+         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.add_eos_token = True
+         tokenizer.padding_side = "right"
+
+         self.tokenizer = tokenizer
+
+     def load_checkpoint(
+         self,
+         checkpoint: str,
+         quantization: Literal["8bit", "4bit"] | bool | None = None
+     ) -> None:
+         if self.model:
+             self._log("A model is already loaded. Attempting to reset it.", "WARNING")
+             self.unload_model()
+
+         self._log(f"Loading model from '{checkpoint}'")
+
+         self._load_tokenizer(checkpoint)
+         self._load_model(
+             checkpoint=checkpoint,
+             quantization=quantization
+         )
+
+         self._log("Model & Tokenizer loaded")
+
+         if quantization:
+             self.model_is_quantized = True
+
+         if not self._model_id:
+             self._create_model_id()
+
+         stop_tokens = []
+         pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
+         if pad_token_id is not None:
+             stop_tokens.append(pad_token_id)
+         eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
+         if eos_token_id is not None:
+             stop_tokens.append(eos_token_id)
+
+         self._set_generation_stopping_tokens(stop_tokens)
+         self.stop_token_ids = list(set(self.stop_token_ids))
+
+     def from_pretrained(
+         self,
+         checkpoint: str,
+         quantization: Literal["8bit", "4bit"] | bool | None = None
+     ) -> None:
+         self.load_checkpoint(
+             checkpoint=checkpoint,
+             quantization=quantization
+         )
+         with open(os.path.join(checkpoint, "custom_info.json"), "r") as f:
+             data = json.load(f)
+             self._model_id = data.get("model_id", None)
+
+     def _create_model_id(
+         self
+     ) -> None:
+         self._model_id = str(uuid4())
+
+     def _set_logger(
+         self,
+         level: str
+     ) -> None:
+         level_map = {
+             "DEBUG": logging.DEBUG,
+             "INFO": logging.INFO,
+             "WARNING": logging.WARNING,
+             "ERROR": logging.ERROR,
+         }
+         numeric_level = level_map.get(level.upper(), logging.INFO)
+
+         logging.basicConfig(
+             level=numeric_level,
+             format="%(asctime)s - %(levelname)s - %(message)s"
+         )
+         self.logger = logging.getLogger(__name__)
+
+     def _log(
+         self,
+         info: str,
+         level: Literal["INFO", "WARNING", "ERROR", "DEBUG"] = "INFO"
+     ) -> None:
+         if level == "INFO":
+             colored_msg = f"{Fore.GREEN}{info}{Style.RESET_ALL}"
+             self.logger.info(colored_msg)
+         elif level == "WARNING":
+             colored_msg = f"{Fore.YELLOW}{info}{Style.RESET_ALL}"
+             self.logger.warning(colored_msg)
+         elif level == "ERROR":
+             colored_msg = f"{Fore.RED}{info}{Style.RESET_ALL}"
+             self.logger.error(colored_msg)
+         elif level == "DEBUG":
+             colored_msg = f"{Fore.BLUE}{info}{Style.RESET_ALL}"
+             self.logger.debug(colored_msg)
+
+     def _set_seed(
+         self,
+         seed: int
+     ) -> None:
+         self.seed = seed
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+         torch.use_deterministic_algorithms(True)
+
+         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+         os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+     def save_checkpoint(
+         self,
+         path: str
+     ) -> None:
+         if not self.model:
+             self._log("No model to save.", "WARNING")
+             return None
+         if not self.tokenizer:
+             self._log("No tokenizer to save.", "WARNING")
+             return None
+
+         os.makedirs(path, exist_ok=True)
+
+         self._log("Saving model...")
+         model_to_save = self.model
+
+         model_to_save.save_pretrained(path)
+         self.tokenizer.save_pretrained(path)
+
+         self._log(f"Model and Tokenizer saved at {path}")
+
+         with open(os.path.join(path, "custom_info.json"), "w") as f:
+             json.dump({
+                 "model_id": self._model_id
+             }, f)
+
+         self._log(f"Model custom information saved at {path}")
+
+     @abstractmethod
+     def _set_generation_stopping_tokens(
+         self,
+         tokens: list[int]
+     ) -> None:
+         pass
+
+     @abstractmethod
+     def _build_input(
+         self,
+         *args: Any,
+         **kwargs: Any
+     ) -> str:
+         pass
+
+     def _tokenize(
+         self,
+         input_text: str
+     ) -> tuple[Tensor, Tensor]:
+         if self.model is None or self.tokenizer is None:
+             raise MissingEssentialProp("Model or Tokenizer missing")
+
+         tokenized_input_text: BatchEncoding = self.tokenizer(
+             input_text,
+             return_tensors="pt"
+         ).to(self.model.device)
+
+         input_ids = tokenized_input_text["input_ids"]
+         input_ids = cast(Tensor, input_ids)
+         attention_mask = tokenized_input_text["attention_mask"]
+         attention_mask = cast(Tensor, attention_mask)
+         return (input_ids, attention_mask)
+
+     def _tokenize_for_dapt(
+         self,
+         input_text: str
+     ) -> tuple[list[int], list[int]]:
+         if self.model is None or self.tokenizer is None:
+             raise MissingEssentialProp("Model or Tokenizer missing")
+
+         tokenized = self.tokenizer(
+             input_text
+         )
+
+         input_ids = tokenized["input_ids"]
+         attention_mask = tokenized["attention_mask"]
+
+         return input_ids, attention_mask
+
+     def _tokenize_dataset_for_dapt(
+         self,
+         dataset: list[str]
+     ) -> Dataset:
+         tokenized = []
+         for input_text in dataset:
+             tokenized_input = self._tokenize_for_dapt(input_text)
+             if tokenized_input:
+                 input_ids, attention_mask = tokenized_input
+                 tokenized.append({
+                     "input_ids": input_ids,
+                     "attention_mask": attention_mask
+                 })
+         return Dataset.from_list(tokenized)
+
+     def _promptfy_dataset_for_dapt(
+         self,
+         dataset: list[dict[str, str | None]]
+     ) -> list[str]:
+         output = []
+         for data in dataset:
+             complete_input = self._build_input(
+                 **{field: data.get(field) for field in self.question_fields + self.answer_fields}
+             )
+             output.append(complete_input)
+
+         return output
+
+     def dapt(
+         self,
+         train_dataset: list[Any],
+         params: TrainParams | None = None,
+         eval_dataset: list[Any] | None = None,
+         save_at_end: bool = True,
+         save_path: str | None = None
+     ) -> None:
+         if not self.model:
+             self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+             return None
+         if not self.tokenizer:
+             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+             return None
+
+         self._log("Starting DAPT")
+
+         if self.model_is_quantized:
+             self._log("Cannot DAPT a quantized model.", "WARNING")
+             return None
+
+         if params is None:
+             params = TrainParams()
+
+         training_arguments = SFTConfig(
+             num_train_epochs=params.epochs,
+             learning_rate=params.lr,
+             gradient_accumulation_steps=params.gradient_accumulation,
+             warmup_ratio=params.warmup_ratio,
+             lr_scheduler_type="cosine_with_min_lr",
+             lr_scheduler_kwargs={"min_lr_rate": 0.1},
+             output_dir=None,
+             save_strategy="no",
+             logging_steps=params.logging_steps
+         )
+
+         if self.seed is not None:
+             training_arguments.seed = self.seed
+
+         processed_train_dataset = self._promptfy_dataset_for_dapt(train_dataset)
+         tokenized_train_dataset = self._tokenize_dataset_for_dapt(processed_train_dataset)
+
+         tokenized_eval_dataset = None
+         if eval_dataset:
+             processed_eval_dataset = self._promptfy_dataset_for_dapt(eval_dataset)
+             tokenized_eval_dataset = self._tokenize_dataset_for_dapt(processed_eval_dataset)
+
+         log_callback = LogCollectorCallback()
+
+         trainer = SFTTrainer(
+             model=self.model,
+             train_dataset=tokenized_train_dataset,
+             eval_dataset=tokenized_eval_dataset,
+             args=training_arguments,
+             callbacks=[log_callback]
+         )
+
+         trainer.train()
+
+         if save_at_end and save_path:
+             self.save_checkpoint(
+                 path=save_path
+             )
+
+         self._log("Finished DAPT")
+
+     def _tokenize_for_fine_tune(
+         self,
+         input_text: str,
+         expected_text: str
+     ) -> tuple[Tensor, Tensor, Tensor]:
+         if self.model is None or self.tokenizer is None:
+             raise MissingEssentialProp("Model or Tokenizer missing")
+
+         encoded_input = self.tokenizer(
+             input_text
+         )
+         encoded_expected = self.tokenizer(
+             expected_text
+         )
+
+         input_ids = torch.tensor(encoded_expected["input_ids"], dtype=torch.long)
+         attention_mask = torch.tensor(encoded_expected["attention_mask"], dtype=torch.bool)
+
+         labels = torch.full_like(input_ids, -100)
+
+         start = len(cast(list, encoded_input["input_ids"]))
+
+         labels[start:] = input_ids[start:]
+
+         return input_ids, attention_mask, labels
+
+     def _tokenize_dataset_for_fine_tune(
+         self,
+         dataset: list[dict[Literal["partial", "complete"], str]]
+     ) -> Dataset:
+         tokenized = []
+
+         for data in dataset:
+             tokenized_input = self._tokenize_for_fine_tune(
+                 input_text=data["partial"],
+                 expected_text=data["complete"]
+             )
+
+             input_ids, attention_mask, labels = tokenized_input
+             tokenized.append({
+                 "input_ids": input_ids,
+                 "attention_mask": attention_mask,
+                 "labels": labels
+             })
+         return Dataset.from_list(tokenized)
+
+     def _build_input_for_fine_tune(
+         self,
+         input: dict
+     ) -> dict[Literal["partial", "complete"], str]:
+         if not self.tokenizer:
+             raise MissingEssentialProp("Could not find tokenizer.")
+
+         partial = self._build_input(**{k: input[k] for k in self.question_fields if k in input})
+
+         complete = self._build_input(**{k: input[k] for k in self.question_fields + self.answer_fields if k in input})
+
+         return {
+             "partial": partial,
+             "complete": complete
+         }
+
+     def _promptfy_dataset_for_fine_tune(
+         self,
+         dataset: list[Any]
+     ) -> list[dict[Literal["partial", "complete"], str]]:
+         output = []
+         for data in dataset:
+             built_inputs = self._build_input_for_fine_tune(
+                 input=data
+             )
+             output.append(built_inputs)
+
+         return output
+
+     def fine_tune(
+         self,
+         train_dataset: list[Any],
+         params: TrainParams | None = None,
+         eval_dataset: list[Any] | None = None,
+         save_at_end: bool = True,
+         save_path: str | None = None
+     ) -> None:
+         if not self.model:
+             self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+             return None
+         if not self.tokenizer:
+             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+             return None
+
+         self._log("Starting fine-tune")
+
+         if self.model_is_quantized:
+             self._log("Cannot fine-tune a quantized model.", "WARNING")
+             return None
+
+         if params is None:
+             params = TrainParams()
+
+         training_arguments = SFTConfig(
+             learning_rate=params.lr,
+             gradient_checkpointing=True,
+             num_train_epochs=params.epochs,
+             gradient_accumulation_steps=params.gradient_accumulation,
+             warmup_ratio=params.warmup_ratio,
+             lr_scheduler_type="cosine_with_min_lr",
+             lr_scheduler_kwargs={"min_lr_rate": 0.1},
+             output_dir=None,
+             save_strategy="no",
+             logging_steps=params.logging_steps
+         )
+
+         if self.seed is not None:
+             training_arguments.seed = self.seed
+
+         processed_train_dataset = self._promptfy_dataset_for_fine_tune(train_dataset)
+         tokenized_train_dataset = self._tokenize_dataset_for_fine_tune(processed_train_dataset)
+
+         tokenized_eval_dataset = None
+         if eval_dataset:
+             processed_eval_dataset = self._promptfy_dataset_for_fine_tune(eval_dataset)
+             tokenized_eval_dataset = self._tokenize_dataset_for_fine_tune(processed_eval_dataset)
+
+         log_callback = LogCollectorCallback()
+
+         trainer = SFTTrainer(
+             model=self.model,
+             train_dataset=tokenized_train_dataset,
+             eval_dataset=tokenized_eval_dataset,
+             args=training_arguments,
+             callbacks=[log_callback]
+         )
+
+         trainer.train()
+
+         if save_at_end and save_path:
+             self.save_checkpoint(
+                 path=save_path
+             )
+
+         self._log("Finished fine-tune")
+
+     @abstractmethod
+     def generate(
+         self,
+         input: Any,
+         params: GenerationParams | None = None
+     ) -> str | None:
+         pass
+
+     def unload_model(self) -> None:
+         try:
+             self._log("Trying to reset model...")
+             del self.model
+             self.model = None
+             self.model_is_quantized = None
+             self.process_id = None
+             self._model_id = None
+             self._log("Reset successfully.")
+         except Exception as e:
+             self._log("Couldn't reset model...", "ERROR")
+             self._log(f"{str(e)}", "DEBUG")
+
+     def set_seed(self, seed: int) -> None:
+         self._log(f"Setting seed value {seed}")
+         self._set_seed(seed)
+         self._log("Seed set.")
+
+     def __del__(self) -> None:
+         self.unload_model()
+         del self.tokenizer
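Since `BaseModel` leaves `_load_model`, `_set_generation_stopping_tokens`, `_build_input`, and `generate` abstract, and requires subclasses to declare `question_fields` and `answer_fields`, a concrete subclass looks roughly like the sketch below. The prompt format and the plain `AutoModelForCausalLM` loading are illustrative assumptions, not the package's actual Gemma/LLaMA3/GPT_OSS implementations; quantization handling is omitted, and `BaseModel` is assumed importable from its module.

from typing import Any, Literal

from transformers import AutoModelForCausalLM

from llmflowstack.schemas.params import GenerationParams


class SimpleQAModel(BaseModel):
    # Declare which dataset fields form the prompt and which form the target.
    question_fields = ["question"]
    answer_fields = ["answer"]

    def _load_model(
        self,
        checkpoint: str,
        quantization: Literal["8bit", "4bit"] | bool | None = None
    ) -> None:
        # Sketch only: quantization is ignored here.
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint)

    def _set_generation_stopping_tokens(self, tokens: list[int]) -> None:
        self.stop_token_ids = self.stop_token_ids + tokens

    def _build_input(self, question: str | None = None, answer: str | None = None) -> str:
        # A real subclass would apply the model's chat template here.
        prompt = f"Question: {question}\nAnswer:"
        return f"{prompt} {answer}" if answer is not None else prompt

    def generate(self, input: Any, params: GenerationParams | None = None) -> str | None:
        input_ids, attention_mask = self._tokenize(self._build_input(question=input))
        output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)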
@@ -0,0 +1,21 @@
+ from typing import Any, Optional
+
+ from transformers.trainer_callback import (TrainerCallback, TrainerControl,
+                                            TrainerState)
+ from transformers.training_args import TrainingArguments
+
+
+ class LogCollectorCallback(TrainerCallback):
+     def __init__(self):
+         self.logs: list[dict] = []
+
+     def on_log(
+         self,
+         args: TrainingArguments,
+         state: TrainerState,
+         control: TrainerControl,
+         logs: Optional[dict[str, Any]] = None,
+         **kwargs
+     ):
+         if logs is not None:
+             self.logs.append(logs.copy())
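`BaseModel.dapt` and `BaseModel.fine_tune` use this callback in exactly this way: attach it to the trainer, then read the collected `on_log` payloads afterwards. A short sketch, assuming `model`, `tokenized_dataset`, and `training_arguments` are already prepared:

from trl.trainer.sft_trainer import SFTTrainer

log_callback = LogCollectorCallback()
trainer = SFTTrainer(
    model=model,                      # assumed: a loaded transformers model
    train_dataset=tokenized_dataset,  # assumed: a pre-tokenized datasets.Dataset
    args=training_arguments,          # assumed: an SFTConfig
    callbacks=[log_callback],
)
trainer.train()

# Each collected entry is the raw trainer log dict, e.g. {"loss": ..., "epoch": ...}
for entry in log_callback.logs:
    print(entry)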
@@ -0,0 +1,16 @@
+ import torch
+ from transformers import StoppingCriteria
+
+
+ class StopOnToken(StoppingCriteria):
+     def __init__(
+         self,
+         stop_token_ids: list[int],
+     ) -> None:
+         self.stop_token_ids = torch.tensor(stop_token_ids)
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         last_token = input_ids[0, -1]
+         stop_tokens = self.stop_token_ids.to(input_ids.device)
+
+         return bool((last_token == stop_tokens).any())
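`StopOnToken` plugs into the standard `transformers` generation loop through `StoppingCriteriaList`. Note that it inspects only `input_ids[0, -1]`, so it is written for batch-size-1 generation. A small sketch, assuming a loaded `model` and `tokenizer`:

from transformers import StoppingCriteriaList

criteria = StoppingCriteriaList([
    StopOnToken(stop_token_ids=[tokenizer.eos_token_id])
])

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output_ids = model.generate(
    **inputs,
    stopping_criteria=criteria,
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))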