llmflowstack 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmflowstack/__init__.py +19 -0
- llmflowstack/base/__init__.py +0 -0
- llmflowstack/base/base.py +527 -0
- llmflowstack/callbacks/__init__.py +0 -0
- llmflowstack/callbacks/log_collector.py +21 -0
- llmflowstack/callbacks/stop_on_token.py +16 -0
- llmflowstack/models/GPT_OSS.py +265 -0
- llmflowstack/models/Gemma.py +247 -0
- llmflowstack/models/LLaMA3.py +213 -0
- llmflowstack/models/__init__.py +9 -0
- llmflowstack/rag/__init__.py +5 -0
- llmflowstack/rag/pipeline.py +114 -0
- llmflowstack/schemas/__init__.py +9 -0
- llmflowstack/schemas/params.py +39 -0
- llmflowstack/utils/__init__.py +11 -0
- llmflowstack/utils/evaluation_methods.py +92 -0
- llmflowstack/utils/exceptions.py +2 -0
- llmflowstack/utils/generation_utils.py +30 -0
- llmflowstack-1.0.0.dist-info/METADATA +229 -0
- llmflowstack-1.0.0.dist-info/RECORD +22 -0
- llmflowstack-1.0.0.dist-info/WHEEL +4 -0
- llmflowstack-1.0.0.dist-info/licenses/LICENSE +21 -0
llmflowstack/__init__.py
ADDED
@@ -0,0 +1,19 @@
```python
from .models.Gemma import Gemma
from .models.GPT_OSS import GPT_OSS
from .models.LLaMA3 import LLaMA3
from .rag.pipeline import RAGPipeline
from .schemas.params import (GenerationBeamsParams, GenerationParams,
                             GenerationSampleParams, TrainParams)
from .utils.evaluation_methods import text_evaluation

__all__ = [
    "Gemma",
    "LLaMA3",
    "GPT_OSS",
    "RAGPipeline",
    "GenerationBeamsParams",
    "GenerationParams",
    "GenerationSampleParams",
    "TrainParams",
    "text_evaluation"
]
```
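Taken together, these exports imply a workflow along the following lines. This is a minimal sketch only: the checkpoint name and prompt are illustrative, and it assumes `LLaMA3` keeps `BaseModel`'s constructor signature and that `GenerationParams` is default-constructible.

```python
# Minimal usage sketch of the public API; the checkpoint name and input are
# illustrative assumptions, not documented behaviour of the package.
from llmflowstack import GenerationParams, LLaMA3

model = LLaMA3(
    checkpoint="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical checkpoint
    quantization="4bit",
    seed=42,
)
answer = model.generate(
    input="Summarize what domain-adaptive pretraining does.",
    params=GenerationParams(),  # assumes params can be built with defaults
)
print(answer)
```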
llmflowstack/base/__init__.py
File without changes

llmflowstack/base/base.py
ADDED
@@ -0,0 +1,527 @@
```python
import json
import logging
import os
import random
from abc import ABC, abstractmethod
from typing import Any, Literal, cast
from uuid import uuid4

import numpy as np
import torch
from colorama import Fore, Style, init
from datasets import Dataset
from torch import Tensor
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers.tokenization_utils_base import BatchEncoding
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer

from llmflowstack.callbacks.log_collector import LogCollectorCallback
from llmflowstack.schemas.params import GenerationParams, TrainParams
from llmflowstack.utils.exceptions import MissingEssentialProp


class BaseModel(ABC):
    """Shared scaffolding for the concrete model wrappers: checkpoint
    loading/saving, seeding, logging, and DAPT/fine-tuning via TRL's
    SFTTrainer."""

    model = None
    tokenizer = None
    _model_id = None
    model_is_quantized = None
    seed = None
    log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO"
    stop_token_ids = []
    question_fields = []
    answer_fields = []

    def __init__(
        self,
        checkpoint: str | None = None,
        quantization: Literal["8bit", "4bit"] | bool | None = None,
        seed: int | None = None,
        log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
    ) -> None:
        if not self.question_fields or not self.answer_fields:
            raise NotImplementedError("Subclasses must define question_fields and answer_fields.")

        init(autoreset=True)
        if seed is not None:  # "is not None" so that seed=0 is honoured
            self._set_seed(seed)

        self._base_model = checkpoint

        self._set_logger(log_level)
        self.log_level = log_level

        self.tokenizer: PreTrainedTokenizerBase | None = None

        if checkpoint:
            self._checkpoint = checkpoint
            self.load_checkpoint(
                checkpoint=checkpoint,
                quantization=quantization
            )

    @abstractmethod
    def _load_model(
        self,
        checkpoint: str,
        quantization: Literal["8bit", "4bit"] | bool | None = None
    ) -> None:
        pass

    def _load_tokenizer(self, checkpoint: str) -> None:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
        tokenizer.padding_side = "right"

        self.tokenizer = tokenizer

    def load_checkpoint(
        self,
        checkpoint: str,
        quantization: Literal["8bit", "4bit"] | bool | None = None
    ) -> None:
        if self.model:
            self._log("A model is already loaded. Attempting to reset it.", "WARNING")
            self.unload_model()

        self._log(f"Loading model from '{checkpoint}'")

        self._load_tokenizer(checkpoint)
        self._load_model(
            checkpoint=checkpoint,
            quantization=quantization
        )

        self._log("Model & Tokenizer loaded")

        if quantization:
            self.model_is_quantized = True

        if not self._model_id:
            self._create_model_id()

        # Collect pad/eos ids as default stopping tokens; "is not None"
        # keeps a legitimate token id of 0.
        stop_tokens = []
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is not None:
            stop_tokens.append(pad_token_id)
        eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
        if eos_token_id is not None:
            stop_tokens.append(eos_token_id)

        self._set_generation_stopping_tokens(stop_tokens)
        self.stop_token_ids = list(set(self.stop_token_ids))

    def from_pretrained(
        self,
        checkpoint: str,
        quantization: Literal["8bit", "4bit"] | bool | None = None
    ) -> None:
        self.load_checkpoint(
            checkpoint=checkpoint,
            quantization=quantization
        )
        with open(os.path.join(checkpoint, "custom_info.json"), "r") as f:
            data = json.load(f)
            self._model_id = data.get("model_id", None)

    def _create_model_id(
        self
    ) -> None:
        # Stored as a string so it survives the json.dump in save_checkpoint.
        self._model_id = str(uuid4())

    def _set_logger(
        self,
        level: str
    ) -> None:
        level_map = {
            "DEBUG": logging.DEBUG,
            "INFO": logging.INFO,
            "WARNING": logging.WARNING,
            "ERROR": logging.ERROR,
        }
        numeric_level = level_map.get(level.upper(), logging.INFO)

        logging.basicConfig(
            level=numeric_level,
            format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

    def _log(
        self,
        info: str,
        level: Literal["INFO", "WARNING", "ERROR", "DEBUG"] = "INFO"
    ) -> None:
        if level == "INFO":
            colored_msg = f"{Fore.GREEN}{info}{Style.RESET_ALL}"
            self.logger.info(colored_msg)
        elif level == "WARNING":
            colored_msg = f"{Fore.YELLOW}{info}{Style.RESET_ALL}"
            self.logger.warning(colored_msg)
        elif level == "ERROR":
            colored_msg = f"{Fore.RED}{info}{Style.RESET_ALL}"
            self.logger.error(colored_msg)
        elif level == "DEBUG":
            colored_msg = f"{Fore.BLUE}{info}{Style.RESET_ALL}"
            self.logger.debug(colored_msg)

    def _set_seed(
        self,
        seed: int
    ) -> None:
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)

        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

    def save_checkpoint(
        self,
        path: str
    ) -> None:
        if not self.model:
            self._log("No model to save.", "WARNING")
            return None
        if not self.tokenizer:
            self._log("No tokenizer to save.", "WARNING")
            return None

        os.makedirs(path, exist_ok=True)

        self._log("Saving model...")
        model_to_save = self.model

        model_to_save.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        self._log(f"Model and Tokenizer saved at {path}")

        with open(os.path.join(path, "custom_info.json"), "w") as f:
            json.dump({
                "model_id": self._model_id
            }, f)

        self._log(f"Model custom information saved at {path}")

    @abstractmethod
    def _set_generation_stopping_tokens(
        self,
        tokens: list[int]
    ) -> None:
        pass

    @abstractmethod
    def _build_input(
        self,
        *args: Any,
        **kwargs: Any
    ) -> str:
        pass

    def _tokenize(
        self,
        input_text: str
    ) -> tuple[Tensor, Tensor]:
        if self.model is None or self.tokenizer is None:
            raise MissingEssentialProp("Model or Tokenizer missing")

        tokenized_input_text: BatchEncoding = self.tokenizer(
            input_text,
            return_tensors="pt"
        ).to(self.model.device)

        input_ids = tokenized_input_text["input_ids"]
        input_ids = cast(Tensor, input_ids)
        attention_mask = tokenized_input_text["attention_mask"]
        attention_mask = cast(Tensor, attention_mask)
        return (input_ids, attention_mask)

    def _tokenize_for_dapt(
        self,
        input_text: str
    ) -> tuple:
        if self.model is None or self.tokenizer is None:
            raise MissingEssentialProp("Model or Tokenizer missing")

        tokenized = self.tokenizer(
            input_text
        )

        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]

        return input_ids, attention_mask

    def _tokenize_dataset_for_dapt(
        self,
        dataset: list[str]
    ) -> Dataset:
        tokenized = []
        for input_text in dataset:
            tokenized_input = self._tokenize_for_dapt(input_text)
            if tokenized_input:
                input_ids, attention_mask = tokenized_input
                tokenized.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask
                })
        return Dataset.from_list(tokenized)

    def _promptfy_dataset_for_dapt(
        self,
        dataset: list[dict[str, str | None]]
    ) -> list[str]:
        output = []
        for data in dataset:
            complete_input = self._build_input(
                **{field: data.get(field) for field in self.question_fields + self.answer_fields}
            )
            output.append(complete_input)

        return output

    def dapt(
        self,
        train_dataset: list[Any],
        params: TrainParams | None = None,
        eval_dataset: list[Any] | None = None,
        save_at_end: bool = True,
        save_path: str | None = None
    ) -> None:
        if not self.model:
            self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
            return None
        if not self.tokenizer:
            self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
            return None

        self._log("Starting DAPT")

        if self.model_is_quantized:
            self._log("Cannot DAPT a quantized model.", "WARNING")
            return None

        if params is None:
            params = TrainParams()

        training_arguments = SFTConfig(
            num_train_epochs=params.epochs,
            learning_rate=params.lr,
            gradient_accumulation_steps=params.gradient_accumulation,
            warmup_ratio=params.warmup_ratio,
            lr_scheduler_type="cosine_with_min_lr",
            lr_scheduler_kwargs={"min_lr_rate": 0.1},
            output_dir=None,
            save_strategy="no",
            logging_steps=params.logging_steps
        )

        if self.seed is not None:
            training_arguments.seed = self.seed

        processed_train_dataset = self._promptfy_dataset_for_dapt(train_dataset)
        tokenized_train_dataset = self._tokenize_dataset_for_dapt(processed_train_dataset)

        tokenized_eval_dataset = None
        if eval_dataset:
            processed_eval_dataset = self._promptfy_dataset_for_dapt(eval_dataset)
            tokenized_eval_dataset = self._tokenize_dataset_for_dapt(processed_eval_dataset)

        log_callback = LogCollectorCallback()

        trainer = SFTTrainer(
            model=self.model,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            args=training_arguments,
            callbacks=[log_callback]
        )

        trainer.train()

        if save_at_end and save_path:
            self.save_checkpoint(
                path=save_path
            )

        self._log("Finished DAPT")

    def _tokenize_for_fine_tune(
        self,
        input_text: str,
        expected_text: str
    ) -> tuple[Tensor, Tensor, Tensor]:
        if self.model is None or self.tokenizer is None:
            raise MissingEssentialProp("Model or Tokenizer missing")

        encoded_input = self.tokenizer(
            input_text
        )
        encoded_expected = self.tokenizer(
            expected_text
        )

        input_ids = torch.tensor(encoded_expected["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(encoded_expected["attention_mask"], dtype=torch.bool)

        labels = torch.full_like(input_ids, -100)

        # Supervise only the answer span: every position before `start`
        # keeps the -100 ignore index and is skipped by the loss.
        start = len(cast(list, encoded_input["input_ids"]))

        labels[start:] = input_ids[start:]

        return input_ids, attention_mask, labels
```
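The masking in `_tokenize_for_fine_tune` is the core of the fine-tune path, so here is a quick self-contained illustration of what it produces; the token ids are invented for the example:

```python
# Worked illustration of the -100 masking above (token ids are invented).
import torch

prompt_ids = [101, 7, 8]             # tokenizer("partial")  -> 3 tokens
full_ids = [101, 7, 8, 42, 43, 102]  # tokenizer("complete") -> 6 tokens

input_ids = torch.tensor(full_ids)
labels = torch.full_like(input_ids, -100)  # -100 is PyTorch's CE ignore_index
labels[len(prompt_ids):] = input_ids[len(prompt_ids):]

print(labels.tolist())  # [-100, -100, -100, 42, 43, 102]
```

`base.py` continues below: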
```python
    def _tokenize_dataset_for_fine_tune(
        self,
        dataset: list[dict[Literal["partial", "complete"], str]]
    ) -> Dataset:
        tokenized = []

        for data in dataset:
            tokenized_input = self._tokenize_for_fine_tune(
                input_text=data["partial"],
                expected_text=data["complete"]
            )

            input_ids, attention_mask, labels = tokenized_input
            tokenized.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            })
        return Dataset.from_list(tokenized)

    def _build_input_for_fine_tune(
        self,
        input: dict
    ) -> dict[Literal["partial", "complete"], str]:
        if not self.tokenizer:
            raise MissingEssentialProp("Could not find tokenizer.")

        partial = self._build_input(**{k: input[k] for k in self.question_fields if k in input})

        complete = self._build_input(**{k: input[k] for k in self.question_fields + self.answer_fields if k in input})

        return {
            "partial": partial,
            "complete": complete
        }

    def _promptfy_dataset_for_fine_tune(
        self,
        dataset: list[Any]
    ) -> list[dict[Literal["partial", "complete"], str]]:
        output = []
        for data in dataset:
            builded_inputs = self._build_input_for_fine_tune(
                input=data
            )
            output.append(builded_inputs)

        return output

    def fine_tune(
        self,
        train_dataset: list[Any],
        params: TrainParams | None = None,
        eval_dataset: list[Any] | None = None,
        save_at_end: bool = True,
        save_path: str | None = None
    ) -> None:
        if not self.model:
            self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
            return None
        if not self.tokenizer:
            self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
            return None

        self._log("Starting fine-tune")

        if self.model_is_quantized:
            self._log("Cannot fine-tune a quantized model.", "WARNING")
            return None

        if params is None:
            params = TrainParams()

        training_arguments = SFTConfig(
            learning_rate=params.lr,
            gradient_checkpointing=True,
            num_train_epochs=params.epochs,
            gradient_accumulation_steps=params.gradient_accumulation,
            warmup_ratio=params.warmup_ratio,
            lr_scheduler_type="cosine_with_min_lr",
            lr_scheduler_kwargs={"min_lr_rate": 0.1},
            output_dir=None,
            save_strategy="no",
            logging_steps=params.logging_steps
        )

        if self.seed is not None:
            training_arguments.seed = self.seed

        processed_train_dataset = self._promptfy_dataset_for_fine_tune(train_dataset)
        tokenized_train_dataset = self._tokenize_dataset_for_fine_tune(processed_train_dataset)

        tokenized_eval_dataset = None
        if eval_dataset:
            processed_eval_dataset = self._promptfy_dataset_for_fine_tune(eval_dataset)
            tokenized_eval_dataset = self._tokenize_dataset_for_fine_tune(processed_eval_dataset)

        log_callback = LogCollectorCallback()

        trainer = SFTTrainer(
            model=self.model,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            args=training_arguments,
            callbacks=[log_callback]
        )

        trainer.train()

        if save_at_end and save_path:
            self.save_checkpoint(
                path=save_path
            )

        self._log("Finished fine-tune")

    @abstractmethod
    def generate(
        self,
        input: Any,
        params: GenerationParams | None = None
    ) -> str | None:
        pass

    def unload_model(self) -> None:
        try:
            self._log("Trying to reset model...")
            del self.model
            self.model = None
            self.model_is_quantized = None
            self.process_id = None
            self._model_id = None
            self._log("Reset successfully.")
        except Exception as e:
            self._log("Couldn't reset model...", "ERROR")
            self._log(f"{str(e)}", "DEBUG")

    def set_seed(self, seed: int) -> None:
        self._log(f"Setting seed value {seed}")
        self._set_seed(seed)
        self._log("Seed set")

    def __del__(self) -> None:
        self.unload_model()
        del self.tokenizer
```
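`BaseModel` leaves four hooks abstract (`_load_model`, `_set_generation_stopping_tokens`, `_build_input`, `generate`) and requires non-empty `question_fields`/`answer_fields`. A minimal sketch of the subclass contract follows; the prompt template and `generate` body are invented for illustration and are not how the shipped `Gemma`/`LLaMA3`/`GPT_OSS` classes are implemented:

```python
# A sketch of the subclass contract enforced by BaseModel; the class name,
# prompt template, and generate() body are illustrative assumptions.
from typing import Any

from transformers import AutoModelForCausalLM

from llmflowstack.base.base import BaseModel
from llmflowstack.schemas.params import GenerationParams


class TinyQA(BaseModel):
    # BaseModel.__init__ raises NotImplementedError if either list is empty.
    question_fields = ["question"]
    answer_fields = ["answer"]

    def _load_model(self, checkpoint, quantization=None) -> None:
        # Quantization is ignored in this sketch.
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint)

    def _set_generation_stopping_tokens(self, tokens: list[int]) -> None:
        # Rebind instead of mutating the shared class-level list.
        self.stop_token_ids = self.stop_token_ids + tokens

    def _build_input(self, question=None, answer=None) -> str:
        # Called with question_fields only for the "partial" prompt and with
        # answer_fields too for the "complete" training example.
        prompt = f"Q: {question}\nA:"
        return prompt if answer is None else f"{prompt} {answer}"

    def generate(self, input: Any, params: GenerationParams | None = None) -> str | None:
        input_ids, attention_mask = self._tokenize(self._build_input(question=input))
        output = self.model.generate(input_ids, attention_mask=attention_mask)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
```

Note that `_build_input` doubles as both the inference prompt builder and the training-pair builder, which is why `question_fields` and `answer_fields` must be declared up front.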
llmflowstack/callbacks/__init__.py
File without changes

llmflowstack/callbacks/log_collector.py
ADDED
@@ -0,0 +1,21 @@
```python
from typing import Any, Optional

from transformers.trainer_callback import (TrainerCallback, TrainerControl,
                                           TrainerState)
from transformers.training_args import TrainingArguments


class LogCollectorCallback(TrainerCallback):
    """Collects every logs dict the Trainer emits, for later inspection."""

    def __init__(self):
        self.logs: list[dict] = []

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[dict[str, Any]] = None,
        **kwargs
    ):
        if logs is not None:
            # Copy so later mutation of the dict by the trainer doesn't
            # alias the stored entry.
            self.logs.append(logs.copy())
```
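In `base.py` this callback is passed to `SFTTrainer` via `callbacks=[log_callback]` and populated during `trainer.train()`. A self-contained demonstration of its behaviour (the simulated log dicts are invented; `on_log` ignores the args/state/control objects, so `None` stands in for them here):

```python
# Demonstrates that on_log simply snapshots each logs dict it receives.
from llmflowstack.callbacks.log_collector import LogCollectorCallback

callback = LogCollectorCallback()

# Simulate two logging steps as the HF Trainer would trigger them.
callback.on_log(None, None, None, logs={"loss": 1.83, "epoch": 0.5})
callback.on_log(None, None, None, logs={"loss": 1.41, "epoch": 1.0})

print(callback.logs)  # [{'loss': 1.83, 'epoch': 0.5}, {'loss': 1.41, 'epoch': 1.0}]
```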
llmflowstack/callbacks/stop_on_token.py
ADDED
@@ -0,0 +1,16 @@
```python
import torch
from transformers import StoppingCriteria


class StopOnToken(StoppingCriteria):
    def __init__(
        self,
        stop_token_ids: list[int],
    ) -> None:
        self.stop_token_ids = torch.tensor(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Only the last token of batch row 0 is inspected, so this criterion
        # assumes single-sequence generation.
        last_token = input_ids[0, -1]
        stop_tokens = self.stop_token_ids.to(input_ids.device)

        return (last_token == stop_tokens).any()  # type: ignore
```
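`StopOnToken` plugs into `transformers` generation through a `StoppingCriteriaList`. A runnable sketch follows; the checkpoint is an arbitrary tiny model chosen only so the example downloads and runs quickly:

```python
# Sketch of wiring StopOnToken into generate(); the checkpoint name is an
# illustrative choice, not something the package prescribes.
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteriaList)

from llmflowstack.callbacks.stop_on_token import StopOnToken

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

criteria = StoppingCriteriaList([StopOnToken([tokenizer.eos_token_id])])
inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(**inputs, stopping_criteria=criteria, max_new_tokens=20)
print(tokenizer.decode(output[0]))
```

This mirrors how `base.py` gathers `pad_token_id`/`eos_token_id` into `stop_token_ids` at load time and hands them to each subclass's `_set_generation_stopping_tokens`.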