llmflowstack 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmflowstack/__init__.py +19 -0
- llmflowstack/base/__init__.py +0 -0
- llmflowstack/base/base.py +527 -0
- llmflowstack/callbacks/__init__.py +0 -0
- llmflowstack/callbacks/log_collector.py +21 -0
- llmflowstack/callbacks/stop_on_token.py +16 -0
- llmflowstack/models/GPT_OSS.py +265 -0
- llmflowstack/models/Gemma.py +247 -0
- llmflowstack/models/LLaMA3.py +213 -0
- llmflowstack/models/__init__.py +9 -0
- llmflowstack/rag/__init__.py +5 -0
- llmflowstack/rag/pipeline.py +114 -0
- llmflowstack/schemas/__init__.py +9 -0
- llmflowstack/schemas/params.py +39 -0
- llmflowstack/utils/__init__.py +11 -0
- llmflowstack/utils/evaluation_methods.py +92 -0
- llmflowstack/utils/exceptions.py +2 -0
- llmflowstack/utils/generation_utils.py +30 -0
- llmflowstack-1.0.0.dist-info/METADATA +229 -0
- llmflowstack-1.0.0.dist-info/RECORD +22 -0
- llmflowstack-1.0.0.dist-info/WHEEL +4 -0
- llmflowstack-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,213 @@
+import textwrap
+import threading
+from time import time
+from typing import Iterator, Literal, TypedDict, cast
+
+import torch
+from transformers import (AutoTokenizer, StoppingCriteriaList,
+                          TextIteratorStreamer)
+from transformers.models.llama import LlamaForCausalLM
+from transformers.utils.quantization_config import BitsAndBytesConfig
+
+from llmflowstack.base.base import BaseModel
+from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.schemas.params import GenerationParams
+from llmflowstack.utils.exceptions import MissingEssentialProp
+from llmflowstack.utils.generation_utils import create_generation_params
+
+
+class LLaMA3Input(TypedDict):
+    input_text: str
+    expected_answer: str | None
+    system_message: str | None
+
+class LLaMA3(BaseModel):
+    model: LlamaForCausalLM | None = None
+    question_fields = ["input_text", "system_message"]
+    answer_fields = ["expected_answer"]
+
+    def _set_generation_stopping_tokens(
+        self,
+        tokens: list[int]
+    ) -> None:
+        if not self.tokenizer:
+            self._log("Could not set stop tokens - generation may not work...", "WARNING")
+            return None
+        particular_tokens = self.tokenizer.encode("<|eot_id|>")
+        self.stop_token_ids = tokens + particular_tokens
+
+    def _load_model(
+        self,
+        checkpoint: str,
+        quantization: Literal["8bit", "4bit"] | bool | None = None
+    ) -> None:
+        quantization_config = None
+        if quantization == "4bit":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True
+            )
+            self.model_is_quantized = True
+        if quantization == "8bit":
+            quantization_config = BitsAndBytesConfig(
+                load_in_8bit=True
+            )
+            self.model_is_quantized = True
+
+        self.model = LlamaForCausalLM.from_pretrained(
+            checkpoint,
+            quantization_config=quantization_config,
+            dtype="auto",
+            device_map="auto",
+            attn_implementation="eager"
+        )
+
+    def _build_input(
+        self,
+        input_text: str,
+        expected_answer: str | None = None,
+        system_message: str | None = None
+    ) -> str:
+        if not self.tokenizer:
+            raise MissingEssentialProp("Could not find tokenizer.")
+
+        answer = f"{expected_answer}{self.tokenizer.eos_token}" if expected_answer else ""
+
+        return textwrap.dedent(
+            f"<|start_header_id|>system<|end_header_id|>{system_message or ''}\n"
+            f"<|eot_id|><|start_header_id|>user<|end_header_id|>{input_text}\n"
+            f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>{answer}"
+        )
+
+    def build_input(
+        self,
+        input_text: str,
+        system_message: str | None = None,
+        expected_answer: str | None = None
+    ) -> LLaMA3Input:
+        if not self.tokenizer:
+            raise MissingEssentialProp("Could not find tokenizer.")
+
+        return {
+            "input_text": input_text,
+            "system_message": system_message,
+            "expected_answer": expected_answer
+        }
+
+    def generate(
+        self,
+        input: LLaMA3Input | str,
+        params: GenerationParams | None = None
+    ) -> str | None:
+        if self.model is None or self.tokenizer is None:
+            self._log("Model or Tokenizer missing", "WARNING")
+            return None
+
+        self._log("Processing received input...")
+
+        if params is None:
+            params = GenerationParams(max_new_tokens=8192)
+        elif params.max_new_tokens is None:
+            params.max_new_tokens = 8192
+
+        generation_params = create_generation_params(params)
+        self.model.generation_config = generation_params
+
+        model_input = None
+        if isinstance(input, str):
+            model_input = self._build_input(
+                input_text=input
+            )
+        else:
+            model_input = self._build_input(
+                input_text=input["input_text"],
+                system_message=input.get("system_message")
+            )
+
+        tokenized_input = self._tokenize(model_input)
+
+        input_ids, attention_mask = tokenized_input
+
+        self.model.eval()
+        self.model.gradient_checkpointing_disable()
+
+        start = time()
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=True,
+                # The default EOS token is disabled; StopOnToken terminates
+                # generation using self.stop_token_ids instead.
+                eos_token_id=None,
+                stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
+            )
+
+        end = time()
+        total_time = end - start
+
+        self._log(f"Response generated in {total_time:.4f} seconds")
+
+        response = outputs[0][input_ids.shape[1]:]
+
+        return self.tokenizer.decode(response, skip_special_tokens=True)
+
+    def generate_stream(
+        self,
+        input: LLaMA3Input | str,
+        params: GenerationParams | None = None
+    ) -> Iterator[str]:
+        if self.model is None or self.tokenizer is None:
+            self._log("Model or Tokenizer missing", "WARNING")
+            return
+
+        if params is None:
+            params = GenerationParams(max_new_tokens=8192)
+        elif params.max_new_tokens is None:
+            params.max_new_tokens = 8192
+
+        generation_params = create_generation_params(params)
+        self.model.generation_config = generation_params
+
+        if isinstance(input, str):
+            model_input = self._build_input(
+                input_text=input
+            )
+        else:
+            model_input = self._build_input(
+                input_text=input["input_text"],
+                system_message=input.get("system_message")
+            )
+
+        tokenized_input = self._tokenize(model_input)
+        input_ids, attention_mask = tokenized_input
+
+        streamer = TextIteratorStreamer(
+            cast(AutoTokenizer, self.tokenizer),
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        def _generate() -> None:
+            assert self.model is not None
+            with torch.no_grad():
+                self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    use_cache=True,
+                    eos_token_id=None,
+                    streamer=streamer,
+                    stopping_criteria=StoppingCriteriaList([StopOnToken(self.stop_token_ids)])
+                )
+
+        thread = threading.Thread(target=_generate)
+        thread.start()
+
+        for new_text in streamer:
+            yield new_text
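For orientation, a minimal usage sketch of the class above. The public constructor and tokenizer wiring live in base/base.py (not shown in this hunk), so the `LLaMA3(...)` entry point and the checkpoint name below are assumptions, not the package's documented API:

    from llmflowstack.models.LLaMA3 import LLaMA3
    from llmflowstack.schemas.params import GenerationParams

    # Assumed entry point: the BaseModel constructor is expected to load
    # the checkpoint and tokenizer (defined in base/base.py, not shown).
    llm = LLaMA3("meta-llama/Meta-Llama-3-8B-Instruct")

    prompt = llm.build_input(
        input_text="Summarize retrieval-augmented generation in one sentence.",
        system_message="Answer briefly."
    )
    print(llm.generate(prompt, params=GenerationParams(max_new_tokens=256)))

    # Streaming variant: generation runs on a background thread and
    # decoded text chunks are yielded as they arrive.
    for chunk in llm.generate_stream(prompt):
        print(chunk, end="", flush=True)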
@@ -0,0 +1,114 @@
+import uuid
+
+from langchain_chroma import Chroma
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from sentence_transformers import SentenceTransformer
+
+
+class EncoderWrapper(Embeddings):
+    def __init__(
+        self,
+        model: SentenceTransformer
+    ) -> None:
+        self.model = model
+
+    def embed_documents(
+        self,
+        texts: list[str]
+    ) -> list[list[float]]:
+        return self.model.encode(texts, task="retrieval", show_progress_bar=True).tolist()
+
+    def embed_query(
+        self,
+        text: str
+    ) -> list[float]:
+        return self.model.encode(text, task="retrieval", show_progress_bar=True).tolist()
+
+class RAGPipeline:
+    def __init__(
+        self,
+        checkpoint: str,
+        collection_name: str = "rag_memory",
+        persist_directory: str = "./chroma_store",
+        chunk_size: int = 1000,
+        chunk_overlap: int = 200
+    ) -> None:
+
+        self.encoder = SentenceTransformer(checkpoint, trust_remote_code=True)
+
+        self.vector_store = Chroma(
+            collection_name=collection_name,
+            embedding_function=EncoderWrapper(self.encoder),
+            persist_directory=persist_directory
+        )
+
+        self.splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            add_start_index=True,
+        )
+
+    def index_documents(
+        self,
+        docs: list[Document],
+        ids: list[str]
+    ) -> None:
+        splits = self.splitter.split_documents(docs)
+        # Chunk IDs are derived from the first document ID only, so this
+        # method effectively expects a single document (as used by `create`).
+        split_ids = [f"{ids[0]}_{i}" for i in range(len(splits))]
+        self.vector_store.add_documents(splits, ids=split_ids)
+
+    def create(
+        self,
+        information: str,
+        other_info: dict[str, str] | None = None,
+        doc_id: str | None = None,
+        should_index: bool = True
+    ) -> Document:
+        if doc_id is None:
+            doc_id = str(uuid.uuid4())
+
+        doc = Document(
+            page_content=information,
+            metadata={"id": doc_id, **(other_info or {})}
+        )
+
+        if should_index:
+            self.index_documents([doc], ids=[doc_id])
+
+        return doc
+
+    def update(
+        self,
+        doc_id: str,
+        new_information: str,
+        other_info: dict[str, str] | None = None
+    ) -> Document:
+        self.vector_store.delete(ids=[doc_id])
+
+        return self.create(
+            information=new_information,
+            other_info=other_info,
+            doc_id=doc_id
+        )
+
+    def delete(
+        self, doc_id: str
+    ) -> None:
+        self.vector_store.delete(ids=[doc_id])
+
+    def query(
+        self,
+        query: str,
+        k: int = 4,
+        category: str | None = None
+    ) -> str:
+        if category:
+            docs = self.vector_store.similarity_search(
+                query, k=k, filter={"category": category}
+            )
+        else:
+            docs = self.vector_store.similarity_search(query, k=k)
+
+        return "\n\n".join(doc.page_content for doc in docs)
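A sketch of how RAGPipeline is meant to be used. Note that EncoderWrapper passes task="retrieval" to SentenceTransformer.encode, which only works for checkpoints whose remote code accepts that argument (jina-embeddings-v3 style models); the checkpoint below is an assumption:

    from llmflowstack.rag.pipeline import RAGPipeline

    # Assumed checkpoint: one whose encode() accepts a `task` argument.
    rag = RAGPipeline(checkpoint="jinaai/jina-embeddings-v3")

    doc = rag.create(
        "LLaMA 3 uses <|eot_id|> as its end-of-turn marker.",
        other_info={"category": "models"}
    )
    context = rag.query("Which token ends a LLaMA 3 turn?", k=2, category="models")

    rag.update(doc.metadata["id"], "Revised note about LLaMA 3 markers.")
    rag.delete(doc.metadata["id"])

One caveat: chunks are stored under suffixed IDs of the form "{doc_id}_{i}", while `delete` and `update` only delete the bare doc_id, so as written they do not remove previously indexed chunks.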
@@ -0,0 +1,39 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+from transformers import TextIteratorStreamer
+
+
+@dataclass
+class TrainParams:
+    batch_size: int = 1
+    gradient_accumulation: int = 8
+    epochs: int = 1
+    warmup_ratio: float = 0.0
+    lr: float = 2e-5
+    optim: Literal[
+        "adamw_torch",
+        "adamw_torch_fused",
+        "sgd"
+    ] = "adamw_torch"
+    logging_steps: int = 1
+
+@dataclass
+class GenerationBeamsParams:
+    num_beams: int | None = None
+    length_penalty: float | None = None
+    early_stopping: bool | None = None
+
+@dataclass
+class GenerationSampleParams:
+    temperature: float | None = None
+    top_p: float | None = None
+    typical_p: float | None = None
+
+@dataclass
+class GenerationParams:
+    max_new_tokens: int | None = None
+    repetition_penalty: float | None = None
+    sample: GenerationSampleParams = field(default_factory=GenerationSampleParams)
+    beams: GenerationBeamsParams = field(default_factory=GenerationBeamsParams)
+    streamer: TextIteratorStreamer | None = None
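As a worked example of these dataclasses (the names come from this file; the values are only illustrative), nucleus sampling with a token cap would be configured as:

    from llmflowstack.schemas.params import (GenerationParams,
                                             GenerationSampleParams)

    # Nucleus sampling at temperature 0.7, capped at 512 new tokens.
    params = GenerationParams(
        max_new_tokens=512,
        repetition_penalty=1.1,
        sample=GenerationSampleParams(temperature=0.7, top_p=0.9)
    )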
@@ -0,0 +1,11 @@
+from .evaluation_methods import (bert_score_evaluation,
+                                 cosine_similarity_evaluation,
+                                 rouge_evaluation, text_evaluation)
+
+__all__ = [
+    "bert_score_evaluation",
+    "cosine_similarity_evaluation",
+    "rouge_evaluation",
+    "evaluation_methods",
+    "text_evaluation"
+]
@@ -0,0 +1,92 @@
+from evaluate import load
+from nltk.stem.snowball import SnowballStemmer
+from rouge_score import rouge_scorer
+from sentence_transformers import SentenceTransformer, util
+
+
+def stem_texts(texts: list[str]) -> list[str]:
+    stemmer = SnowballStemmer("portuguese")
+
+    stemmed_texts: list[str] = []
+    for text in texts:
+        stemmed_text = " ".join([stemmer.stem(word) for word in text.split()])
+        stemmed_texts.append(stemmed_text)
+
+    return stemmed_texts
+
+def rouge_evaluation(
+    preds: list[str],
+    refs: list[str]
+) -> dict[str, float]:
+    preds_stemmed = stem_texts(preds)
+    refs_stemmed = stem_texts(refs)
+
+    rouge_metrics = {"rouge1": [], "rouge2": [], "rougeL": []}
+    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
+
+    for ref, pred in zip(refs_stemmed, preds_stemmed):
+        scores = scorer.score(ref, pred)
+        for key in rouge_metrics:
+            rouge_metrics[key].append(scores[key].fmeasure)
+
+    return {k: sum(v)/len(v) for k, v in rouge_metrics.items()}
+
+def bert_score_evaluation(
+    preds: list[str],
+    refs: list[str]
+) -> dict[str, float]:
+    bertscore = load("bertscore")
+
+    bert_result = bertscore.compute(predictions=preds, references=refs, lang="pt")
+
+    bert_avg = {}
+    if bert_result:
+        bert_avg = {
+            "bertscore_precision": sum(bert_result["precision"]) / len(bert_result["precision"]),
+            "bertscore_recall": sum(bert_result["recall"]) / len(bert_result["recall"]),
+            "bertscore_f1": sum(bert_result["f1"]) / len(bert_result["f1"])
+        }
+
+    return bert_avg
+
+def cosine_similarity_evaluation(
+    preds: list[str],
+    refs: list[str]
+) -> dict[str, float]:
+    model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
+
+    emb_preds = model.encode(preds, convert_to_tensor=True)
+    emb_refs = model.encode(refs, convert_to_tensor=True)
+
+    cos_sim_matrix = util.cos_sim(emb_preds, emb_refs)
+
+    cos_sim_scores = cos_sim_matrix.diag()
+    avg_cos_sim = cos_sim_scores.mean().item()
+
+    return {"cosine_similarity": float(avg_cos_sim)}
+
+def text_evaluation(
+    preds: list[str],
+    refs: list[str],
+    rouge: bool = True,
+    bert: bool = True,
+    cosine: bool = True
+) -> dict[str, float]:
+    result = {}
+    if rouge:
+        result.update(rouge_evaluation(
+            preds=preds,
+            refs=refs
+        ))
+    if bert:
+        result.update(bert_score_evaluation(
+            preds=preds,
+            refs=refs
+        ))
+    if cosine:
+        result.update(cosine_similarity_evaluation(
+            preds=preds,
+            refs=refs
+        ))
+
+    return result
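These helpers are hard-coded for Portuguese (Snowball stemming and lang="pt" for BERTScore). A minimal sketch of running the combined evaluation, with illustrative strings:

    from llmflowstack.utils import text_evaluation

    preds = ["o gato sentou no tapete"]
    refs = ["o gato está sentado no tapete"]

    # Returns averaged ROUGE-1/2/L F-measures, BERTScore precision/recall/F1,
    # and mean embedding cosine similarity, merged into one dict.
    scores = text_evaluation(preds=preds, refs=refs)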
@@ -0,0 +1,30 @@
+from transformers.generation.configuration_utils import GenerationConfig
+
+from llmflowstack.schemas.params import GenerationParams
+
+
+def create_generation_params(generation_configs: GenerationParams) -> GenerationConfig:
+
+    params = {
+        "max_new_tokens": generation_configs.max_new_tokens,
+        "repetition_penalty": generation_configs.repetition_penalty
+    }
+    sample = generation_configs.sample
+    beams = generation_configs.beams
+    # Sampling settings take precedence; fall back to beam search when only
+    # beam parameters are set, and to the defaults otherwise.
+    if sample.temperature is not None or sample.top_p is not None or sample.typical_p is not None:
+        params.update({
+            "do_sample": True,
+            "temperature": sample.temperature,
+            "top_p": sample.top_p,
+            "typical_p": sample.typical_p,
+            "num_beams": 1
+        })
+    elif beams.num_beams is not None:
+        params.update({
+            "do_sample": False,
+            "num_beams": beams.num_beams,
+            "length_penalty": beams.length_penalty,
+            "early_stopping": beams.early_stopping
+        })
+
+    return GenerationConfig(**params)
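Given the field-based dispatch above (the sampling branch fires when any sampling field is set, the beam branch when num_beams is set), a beam-search configuration maps like this; the values are illustrative:

    from llmflowstack.schemas.params import (GenerationBeamsParams,
                                             GenerationParams)
    from llmflowstack.utils.generation_utils import create_generation_params

    # No sampling fields are set, so the beam-search branch applies.
    config = create_generation_params(GenerationParams(
        max_new_tokens=128,
        beams=GenerationBeamsParams(num_beams=4, early_stopping=True)
    ))
    # config.do_sample is False and config.num_beams == 4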