optimum-rbln 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,700 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import inspect
|
25
|
+
import logging
|
26
|
+
import warnings
|
27
|
+
from pathlib import Path
|
28
|
+
from tempfile import TemporaryDirectory
|
29
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
30
|
+
|
31
|
+
import rebel
|
32
|
+
import torch
|
33
|
+
from optimum.exporters import TasksManager
|
34
|
+
from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
|
35
|
+
from transformers.generation.logits_process import LogitsProcessorList
|
36
|
+
from transformers.generation.stopping_criteria import (
|
37
|
+
StoppingCriteriaList,
|
38
|
+
validate_stopping_criteria,
|
39
|
+
)
|
40
|
+
from transformers.generation.streamers import BaseStreamer
|
41
|
+
from transformers.generation.utils import SampleDecoderOnlyOutput
|
42
|
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
|
43
|
+
|
44
|
+
from ....modeling_base import RBLNBaseModel
|
45
|
+
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
|
46
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
47
|
+
from ....utils.save_utils import maybe_save_preprocessors
|
48
|
+
from .gpt2_architecture import GPT2LMHeadModelWrapper
|
49
|
+
|
50
|
+
|
51
|
+
logger = logging.getLogger(__name__)
|
52
|
+
|
53
|
+
if TYPE_CHECKING:
|
54
|
+
from transformers import (
|
55
|
+
AutoFeatureExtractor,
|
56
|
+
AutoProcessor,
|
57
|
+
AutoTokenizer,
|
58
|
+
PretrainedConfig,
|
59
|
+
)
|
60
|
+
|
61
|
+
|
62
|
+
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
    """Runtime wrapper whose forward pass wraps the raw runtime result in a ``Seq2SeqLMOutput``."""

    def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
        # The underlying runtime's result is used directly as the logits field.
        return Seq2SeqLMOutput(logits=super().forward(*args, **kwargs))
|
67
|
+
|
68
|
+
|
69
|
+
class RBLNGPT2LMHeadModel(RBLNBaseModel):
    """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).

    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models.

    It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    Generation uses two runtimes created from one compiled model: runtime 0 processes the prompt in chunks of
    ``rbln_prefill_chunk_size`` tokens and runtime 1 processes a single token per step. The value threaded through
    `generate()` as ``past_key_values`` is not a key/value cache: it is the number of positions already processed
    (the cached length).
    """

    model_type = "rbln_model"
    auto_model_class = AutoModelForCausalLM
    main_input_name = "input_ids"

    def __post_init__(self, **kwargs):
        """Read compile-time metadata, build the reusable attention masks and wrap both runtimes."""
        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]

        # input_info[0] is ("input_ids", [batch_size, query_length], dtype): the batch size fixed at compile time.
        batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
        self.prefill_attention_mask = torch.zeros(
            batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
        )
        # Lower-triangular mask for one prefill chunk (1 = may attend, 0 = masked future position).
        self.causal_mask = 1 - torch.triu(
            torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
        )

        # Runtime 0 was created with the prefill input shapes, runtime 1 with the single-token decode shapes.
        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
        self.past_cached_length = 0

    def can_generate(self):
        return True

    def __getattr__(self, __name: str) -> Any:
        """This is the key method to implement RBLN-GPT2.

        Attributes not found on this object are looked up on `GPT2LMHeadModel`; methods taking ``self`` are bound
        to this object so generation utilities run against the RBLN model.

        Returns:
            Any: GPT2's corresponding method
        """

        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(GPT2LMHeadModel, __name)
        if callable(val) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val

    def _reorder_cache(self, past_key_values, beam_idx):
        # TODO(jongho): implement
        raise NotImplementedError("`_reorder_cache` is not implemented for RBLNGPT2LMHeadModel.")

    @classmethod
    def _export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "RBLNGPT2LMHeadModel":
        """
        Exports a vanilla Transformers model into a rbln-compiled Module.

        The checkpoint is loaded with TorchScript-friendly settings and traced once per input shape (prefill chunk
        and single-token decode). The two graphs are compiled together with weight sharing, saved to a temporary
        directory together with the configs, and then loaded back through `_from_pretrained`.
        """
        task = kwargs.pop("task", None)
        if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)

        # Handed to `_from_pretrained` as `model_save_dir`, so the directory outlives this call.
        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        kwargs.update(
            {
                "torchscript": True,
                "return_dict": False,
                "use_cache": True,
            }
        )

        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)

        model: GPT2LMHeadModel = TasksManager.get_model_from_task(
            task=task,
            model_name_or_path=model_id,
            subfolder=subfolder,
            revision=revision,
            framework="pt",
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        if config is None:
            config = model.config

        config.save_pretrained(save_dir_path)
        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        # Get compilation arguments. A config supplied through `rbln_config` is used as-is (otherwise the name
        # would be unbound below).
        # NOTE(review): assumes a caller-supplied `rbln_config` is an RBLNConfig instance — confirm.
        rbln_config = rbln_config_kwargs.get("rbln_config", None)
        if rbln_config is None:
            rbln_config = cls.get_rbln_config(
                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
            )

        def compile_gpt2():
            wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()

            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]

            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)

            prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)

            prefill_ir = rebel.torchscript_to_ir(
                prefill_scripted_model,
                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
            )
            dec_ir = rebel.torchscript_to_ir(
                dec_scripted_model,
                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
            )

            # inputs[1] is "past_key_values" (see `_get_rbln_config`); presumably outputs[1] is the updated cache,
            # so this ties the cache buffer between iterations — TODO confirm against the wrapper's outputs.
            connections = [
                (prefill_ir.outputs[1], prefill_ir.inputs[1]),
            ]

            compiled_model = rebel.compile(
                prefill_ir,
                dec_ir,
                connections=connections,
                fusion=prefill_rbln_runtime_config.fusion,
                npu=prefill_rbln_runtime_config.npu,
                tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
                use_weight_sharing=True,
            )
            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")

        compile_gpt2()
        rbln_config.save(save_dir_path)

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            **rbln_constructor_kwargs,
            **kwargs,
        )

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_max_seq_len: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
        rbln_pad_token_id: Optional[int] = None,
    ) -> RBLNConfig:
        """
        Build the compile-time configuration: one runtime config for the prefill shape (a chunk of 128 tokens) and
        one for the single-token decode shape.

        ``rbln_max_seq_len`` falls back to ``model_config.n_positions``, then to the first preprocessor exposing
        ``max_len_single_sentence``. ``rbln_pad_token_id`` falls back to the config's pad, then EOS token id, then
        50256 (GPT-2's ``<|endoftext|>`` id). ``rbln_batch_size`` defaults to 1.

        Raises:
            ValueError: if no maximum sequence length can be determined.
        """
        meta = {}

        default_max_length = getattr(model_config, "n_positions", None)
        # `preprocessors` is iterated as a collection (presumably the list returned by maybe_save_preprocessors).
        for tokenizer in preprocessors or []:
            default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)

        prefill_chunk_size = 128

        if rbln_max_seq_len is None:
            rbln_max_seq_len = default_max_length

        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if rbln_pad_token_id is None:
            rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
            if rbln_pad_token_id is None:
                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
                if rbln_pad_token_id is None:
                    rbln_pad_token_id = 50256

        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
        meta["rbln_max_seq_len"] = rbln_max_seq_len
        meta["rbln_pad_token_id"] = rbln_pad_token_id

        if rbln_batch_size is None:
            rbln_batch_size = 1

        def get_input_info(query_length):
            return [
                ("input_ids", [rbln_batch_size, query_length], "int64"),
                (
                    "past_key_values",
                    [
                        model_config.n_layer,
                        2,
                        rbln_batch_size,
                        model_config.n_head,
                        rbln_max_seq_len,
                        model_config.hidden_size // model_config.n_head,
                    ],
                    "float32",
                ),
                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                (
                    "cache_position",
                    [],
                    "int32",
                ),
            ]

        # model input info
        prefill_input_info = get_input_info(query_length=prefill_chunk_size)
        dec_input_info = get_input_info(query_length=1)

        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)

        rbln_config = RBLNConfig.from_rbln_runtime_configs(
            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
            _rbln_meta=meta,
        )

        return rbln_config

    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        """Create the prefill (input_info 0) and decode (input_info 1) runtimes on the mapped device."""
        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
        return [
            self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
            self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
        ]

    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
        """
        Build the runtime inputs for the next generation step.

        ``past_key_values`` is the number of positions already processed (0, or None, before the first step).

        - First step: the prompt is zero-padded to a multiple of the prefill chunk size and only the shortest
          prompt length (per ``attention_mask``) is prefilled.
        - Later steps: one token is fed, the one at the current cache position.

        Batches are assumed to be right-padded. Remaining prompt tokens of longer rows are re-fed one at a time,
        while generated tokens overwrite the padding of shorter rows in place (see `_rbln_decoding_loop`).
        """
        batch_size, cur_len = input_ids.shape
        past_cached_length = 0 if past_key_values is None else int(past_key_values)

        if past_cached_length == 0:
            self.prompt_ids = input_ids
            self.rightpad_max_len = cur_len
            if attention_mask is None:
                # No mask: every row is a full-length prompt.
                prompt_min_len = cur_len
            else:
                prompt_min_len = int(torch.min(torch.sum(attention_mask, dim=-1)))

            # Zero-pad so the prompt splits into whole prefill chunks.
            pad_len = (-cur_len) % self.prefill_chunk_size
            input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
            attention_mask = self.prefill_attention_mask.clone()
            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)

            query_length = prompt_min_len
        else:
            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
            attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
            attention_mask[:, :, :, : past_cached_length + 1] = 1
            # Feed the token sitting at the current cache position. Once tokens are being appended this is the last
            # column; while generated tokens still overwrite right-padding in place it is an interior column.
            input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
            query_length = 1

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_cached_length,
            "attention_mask": attention_mask,
            # below are rbln-related kwargs
            "cache_position": cache_position,
            "query_length": query_length,
        }

        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        query_length: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """
        Run one generation step.

        When ``cache_position == 0``, the first ``query_length`` prompt tokens go through the prefill runtime chunk
        by chunk, and the logits of the last of them are returned with shape ``(batch, 1, vocab)``. Otherwise the
        single input token goes through the decode runtime.

        ``past_key_values`` (the cached length) is returned incremented by ``query_length``. Arguments are never
        modified in place.

        Raises:
            ValueError: on a prefill step with an empty prompt (``query_length <= 0``).
        """
        if past_key_values is not None:
            past_key_values = past_key_values + query_length

        if cache_position == 0:
            chunk_size = self.prefill_chunk_size
            query_length = int(query_length)
            if query_length <= 0:
                raise ValueError("The prompt must contain at least one non-padding token.")

            remaining = query_length
            for chunk_start in range(0, query_length, chunk_size):
                chunk_end = chunk_start + chunk_size
                sliced_input_ids = input_ids[:, chunk_start:chunk_end]
                # Earlier chunks are fully visible; attention inside the current chunk is causal.
                attention_mask[:, :, :, :chunk_start] = 1
                attention_mask[:, :, :, chunk_start:chunk_end] = self.causal_mask

                output = self.prefill_decoder(
                    input_ids=sliced_input_ids.contiguous(),
                    attention_mask=attention_mask.contiguous(),
                    cache_position=cache_position + chunk_start,
                )
                remaining -= chunk_size

            # `remaining` now lies in (-chunk_size, 0], so `remaining - 1` indexes (from the end of the last chunk's
            # logits) the position of the last prompt token.
            output = output.logits[:, remaining - 1].unsqueeze(1)

        else:
            output = self.decoder(
                input_ids=input_ids.contiguous(),
                attention_mask=attention_mask.contiguous(),
                cache_position=cache_position,
            )
            output = output.logits

        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)

    def __repr__(self):
        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])

    @staticmethod
    def _rbln_prepare_stopping_criteria(
        stopping_criteria: Optional[StoppingCriteriaList], max_length: Optional[int]
    ) -> StoppingCriteriaList:
        """Default the stopping criteria and fold in the deprecated ``max_length`` argument (with a warning)."""
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        return stopping_criteria

    def _rbln_resolve_special_tokens(
        self,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
        device: torch.device,
    ) -> Tuple[Optional[int], Optional[torch.Tensor]]:
        """Fill missing pad/EOS ids from ``self.generation_config``; return ``(pad_token_id, eos_id_tensor)``."""
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(device) if eos_token_id is not None else None
        return pad_token_id, eos_token_id_tensor

    def _rbln_decoding_loop(
        self,
        input_ids: torch.LongTensor,
        select_next_tokens: Callable[
            [torch.LongTensor, torch.FloatTensor], Tuple[torch.LongTensor, torch.FloatTensor]
        ],
        stopping_criteria: StoppingCriteriaList,
        pad_token_id: Optional[int],
        eos_token_id_tensor: Optional[torch.Tensor],
        collect_scores: bool,
        collect_logits: bool,
        streamer: Optional["BaseStreamer"],
        model_kwargs: Dict[str, Any],
        forward_kwargs: Dict[str, Any],
    ) -> Tuple[torch.LongTensor, Optional[Tuple[torch.FloatTensor, ...]], Optional[Tuple[torch.FloatTensor, ...]]]:
        """
        Token-by-token loop shared by `_greedy_search` and `_sample`.

        ``select_next_tokens(input_ids, next_token_logits)`` must return ``(next_tokens, next_token_scores)``.
        Returns ``(sequences, scores, raw_logits)``; ``scores``/``raw_logits`` are None unless collected.

        Right-padded batches: while the write position ``update_idx`` is still inside the padded prompt, the new
        token replaces the padding of rows whose prompt already ended. It is discarded for rows that still have a
        prompt token there ("dummy" rows); those rows are reported to the streamer as ``pad_token_id`` and can
        never be finished by that placeholder.
        TODO: ``scores``/``raw_logits`` still include the dummy rows' entries.
        """
        scores = () if collect_scores else None
        raw_logits = () if collect_logits else None

        # 1 = still generating, 0 = finished (produced an EOS token).
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True, **forward_kwargs)
            next_token_logits = outputs.logits[:, -1, :]

            next_tokens, next_token_scores = select_next_tokens(input_ids, next_token_logits)
            if scores is not None:
                scores += (next_token_scores,)
            if raw_logits is not None:
                raw_logits += (next_token_logits,)

            # Finished sentences should have their next token be a padding token.
            if eos_token_id_tensor is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            update_idx = int(model_inputs["cache_position"]) + int(model_inputs["query_length"])
            dummy_indices = None
            if update_idx < self.rightpad_max_len:
                # Overwrite the padding of rows whose prompt already ended instead of concatenating.
                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
                model_kwargs["attention_mask"][valid_indices, update_idx] = 1

                # Discarded tokens are reported as padding so streamers can drop them (skip_special_tokens=True).
                dummy_indices = ~valid_indices
                next_tokens[dummy_indices] = pad_token_id
            else:
                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            if streamer is not None:
                streamer.put(next_tokens.cpu())

            # If an EOS token was produced in a sentence, mark that sentence finished.
            if eos_token_id_tensor is not None:
                not_eos = (
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
                if dummy_indices is not None:
                    # Only the dummy rows' placeholder (pad, possibly equal to EOS) must not finish a row; rows
                    # whose token was written into the padding still finish normally on EOS.
                    not_eos = not_eos.masked_fill(dummy_indices, 1)
                unfinished_sequences = unfinished_sequences.mul(not_eos)

            all_finished = bool(unfinished_sequences.max() == 0)

            # Stopping criteria may return a bool (older transformers) or a per-row bool tensor (newer).
            is_stop = stopping_criteria(input_ids, None)
            if isinstance(is_stop, torch.Tensor):
                is_stop = bool(torch.all(is_stop))

            if all_finished or is_stop:
                break

        if streamer is not None:
            streamer.end()

        return input_ids, scores, raw_logits

    # Calling `greedy_search` directly is deprecated and was removed in transformers v4.41.
    def greedy_search(self, *args, **kwargs):
        return self._greedy_search(*args, **kwargs)

    def _greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
        """Greedy decoding (argmax of the processed logits) on the RBLN runtimes."""
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = self._rbln_prepare_stopping_criteria(stopping_criteria, max_length)
        pad_token_id, eos_token_id_tensor = self._rbln_resolve_special_tokens(
            pad_token_id, eos_token_id, input_ids.device
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        def select_next_tokens(ids, next_token_logits):
            next_tokens_scores = logits_processor(ids, next_token_logits)
            return torch.argmax(next_tokens_scores, dim=-1), next_tokens_scores

        sequences, _, raw_logits = self._rbln_decoding_loop(
            input_ids,
            select_next_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id_tensor=eos_token_id_tensor,
            collect_scores=False,
            collect_logits=bool(return_dict_in_generate and output_logits),
            streamer=streamer,
            model_kwargs=model_kwargs,
            forward_kwargs={},
        )

        if return_dict_in_generate:
            return SampleDecoderOnlyOutput(
                sequences=sequences,
                logits=raw_logits,
            )
        return sequences

    # Calling `sample` directly is deprecated and was removed in transformers v4.41.
    def sample(self, *args, **kwargs):
        return self._sample(*args, **kwargs)

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
        """Multinomial sampling from the processed and warped distribution on the RBLN runtimes.

        ``synced_gpus`` is accepted for signature compatibility and ignored.
        """
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = self._rbln_prepare_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id, eos_token_id_tensor = self._rbln_resolve_special_tokens(
            pad_token_id, eos_token_id, input_ids.device
        )

        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_logits = output_logits if output_logits is not None else False
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        def select_next_tokens(ids, next_token_logits):
            next_token_scores = logits_processor(ids, next_token_logits)
            next_token_scores = logits_warper(ids, next_token_scores)
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1), next_token_scores

        sequences, scores, raw_logits = self._rbln_decoding_loop(
            input_ids,
            select_next_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id_tensor=eos_token_id_tensor,
            collect_scores=bool(return_dict_in_generate and output_scores),
            collect_logits=bool(return_dict_in_generate and output_logits),
            streamer=streamer,
            model_kwargs=model_kwargs,
            forward_kwargs={
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
            },
        )

        if return_dict_in_generate:
            return SampleDecoderOnlyOutput(
                sequences=sequences,
                scores=scores,
                logits=raw_logits,
            )
        return sequences
|