optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -0,0 +1,700 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import inspect
+ import logging
+ import warnings
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import rebel
+ import torch
+ from optimum.exporters import TasksManager
+ from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
+ from transformers.generation.logits_process import LogitsProcessorList
+ from transformers.generation.stopping_criteria import (
+     StoppingCriteriaList,
+     validate_stopping_criteria,
+ )
+ from transformers.generation.streamers import BaseStreamer
+ from transformers.generation.utils import SampleDecoderOnlyOutput
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+
+ from ....modeling_base import RBLNBaseModel
+ from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ....utils.save_utils import maybe_save_preprocessors
+ from .gpt2_architecture import GPT2LMHeadModelWrapper
+
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+         PretrainedConfig,
+     )
+
+
+ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+     def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
+         outputs = super().forward(*args, **kwargs)
+         logits = outputs
+         return Seq2SeqLMOutput(logits=logits)
+
+
+ class RBLNGPT2LMHeadModel(RBLNBaseModel):
+     """
+     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+     embeddings).
+
+     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models.
+
+     It implements the methods to convert a pre-trained transformers GPT2 model into an RBLN transformer model by:
+     - transferring the checkpoint weights of the original model into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     """
+
+     model_type = "rbln_model"
+     auto_model_class = AutoModelForCausalLM
+     main_input_name = "input_ids"
+
+     def __post_init__(self, **kwargs):
+         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
+
+         batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
+         self.prefill_attention_mask = torch.zeros(
+             batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
+         )
+         self.causal_mask = 1 - torch.triu(
+             torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+         )
+
+         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
+         self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
+         self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
+         self.past_cached_length = 0
+
+     def can_generate(self):
+         return True
+
+     def __getattr__(self, __name: str) -> Any:
+         """This is the key method to implement RBLN-GPT2.
+
+         Returns:
+             Any: GPT2's corresponding method
+         """
+
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(GPT2LMHeadModel, __name)
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         # TODO(jongho): implement
+         raise NotImplementedError
+
+     @classmethod
+     def _export(
+         cls,
+         model_id: str,
+         config: "PretrainedConfig",
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         trust_remote_code: bool = False,
+         **kwargs,
+     ) -> "RBLNGPT2LMHeadModel":
+         """
+         Exports a vanilla Transformers model into an RBLN-compiled module.
+         """
+         task = kwargs.pop("task", None)
+         if task is None:
+             task = TasksManager.infer_task_from_model(cls.auto_model_class)
+
+         save_dir = TemporaryDirectory()
+         save_dir_path = Path(save_dir.name)
+
+         kwargs.update(
+             {
+                 "torchscript": True,
+                 "return_dict": False,
+                 "use_cache": True,
+             }
+         )
+
+         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+
+         model: GPT2LMHeadModel = TasksManager.get_model_from_task(
+             task=task,
+             model_name_or_path=model_id,
+             subfolder=subfolder,
+             revision=revision,
+             framework="pt",
+             cache_dir=cache_dir,
+             use_auth_token=use_auth_token,
+             local_files_only=local_files_only,
+             force_download=force_download,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         if config is None:
+             config = model.config
+
+         config.save_pretrained(save_dir_path)
+         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+
+         # Get compilation arguments
+         if rbln_config_kwargs.get("rbln_config", None) is None:
+             rbln_config = cls.get_rbln_config(
+                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
+             )
+
+         def compile_gpt2():
+             wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
+
+             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+
+             prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
+             dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
+
+             prefill_ir = rebel.torchscript_to_ir(
+                 prefill_scripted_model,
+                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+             )
+             dec_ir = rebel.torchscript_to_ir(
+                 dec_scripted_model,
+                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+             )
+
+             connections = [
+                 (prefill_ir.outputs[1], prefill_ir.inputs[1]),
+             ]
+
+             compiled_model = rebel.compile(
+                 prefill_ir,
+                 dec_ir,
+                 connections=connections,
+                 fusion=prefill_rbln_runtime_config.fusion,
+                 npu=prefill_rbln_runtime_config.npu,
+                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+                 use_weight_sharing=True,
+             )
+             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+
+         compile_gpt2()
+         rbln_config.save(save_dir_path)
+
+         return cls._from_pretrained(
+             model_id=save_dir_path,
+             config=config,
+             model_save_dir=save_dir,
+             **rbln_constructor_kwargs,
+             **kwargs,
+         )
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model_config: "PretrainedConfig",
+         rbln_max_seq_len: Optional[int] = None,
+         rbln_batch_size: Optional[int] = None,
+         rbln_pad_token_id: Optional[int] = None,
+     ) -> RBLNConfig:
+         meta = {}
+
+         default_max_length = getattr(model_config, "n_positions", None)
+         for tokenizer in preprocessors:
+             default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
+
+         prefill_chunk_size = 128
+
+         if rbln_max_seq_len is None:
+             rbln_max_seq_len = default_max_length
+
+         if rbln_max_seq_len is None:
+             raise ValueError("`rbln_max_seq_len` should be specified!")
+
+         if rbln_pad_token_id is None:
+             rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
+             if rbln_pad_token_id is None:
+                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
+                 if rbln_pad_token_id is None:
+                     rbln_pad_token_id = 50256
+
+         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+         meta["rbln_max_seq_len"] = rbln_max_seq_len
+         meta["rbln_pad_token_id"] = rbln_pad_token_id
+
+         if rbln_batch_size is None:
+             rbln_batch_size = 1
+
+         def get_input_info(query_length):
+             return [
+                 ("input_ids", [rbln_batch_size, query_length], "int64"),
+                 (
+                     "past_key_values",
+                     [
+                         model_config.n_layer,
+                         2,
+                         rbln_batch_size,
+                         model_config.n_head,
+                         rbln_max_seq_len,
+                         model_config.hidden_size // model_config.n_head,
+                     ],
+                     "float32",
+                 ),
+                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                 (
+                     "cache_position",
+                     [],
+                     "int32",
+                 ),
+             ]
+
+         # model input info
+         prefill_input_info = get_input_info(query_length=prefill_chunk_size)
+         dec_input_info = get_input_info(query_length=1)
+
+         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
+         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
+
+         rbln_config = RBLNConfig.from_rbln_runtime_configs(
+             [prefill_rbln_runtime_config, dec_rbln_runtime_config],
+             _rbln_meta=meta,
+         )
+
+         return rbln_config
+
+     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+         return [
+             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+         ]
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
+         batch_size, cur_len = input_ids.shape
+         past_cached_length = past_key_values
+
+         # In greedy decoding
+         if past_cached_length == 0:
+             self.prompt_ids = input_ids
+             self.rightpad_max_len = cur_len
+             prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+
+             if cur_len % self.prefill_chunk_size == 0:
+                 pad_len = 0
+             else:
+                 pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
+             input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
+             attention_mask = self.prefill_attention_mask.clone()
+             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
+
+             query_length = prompt_min_len
+         else:
+             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
+             attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
+             attention_mask[:, :, :, : cache_position + 1] = 1
+             input_ids = input_ids[:, -1:].contiguous()
+             query_length = 1
+
+         model_inputs = {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values,
+             "attention_mask": attention_mask,
+             # below are rbln-related kwargs
+             "cache_position": cache_position,
+             "query_length": query_length,
+         }
+
+         return model_inputs
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         query_length: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+
+         if past_key_values is not None:
+             past_key_values += query_length
+
+         if cache_position == 0:
+             for _ in range(0, query_length, self.prefill_chunk_size):
+                 sliced_input_ids = input_ids[:, cache_position : cache_position + self.prefill_chunk_size]
+                 attention_mask[:, :, :, :cache_position] = 1
+                 attention_mask[:, :, :, cache_position : cache_position + self.prefill_chunk_size] = self.causal_mask
+
+                 output = self.prefill_decoder(
+                     input_ids=sliced_input_ids.contiguous(),
+                     attention_mask=attention_mask.contiguous(),
+                     cache_position=cache_position,
+                 )
+                 query_length -= self.prefill_chunk_size
+                 cache_position += self.prefill_chunk_size
+
+             output = output.logits[:, query_length - 1].unsqueeze(1)
+
+         else:
+             output = self.decoder(
+                 input_ids=input_ids.contiguous(),
+                 attention_mask=attention_mask.contiguous(),
+                 cache_position=cache_position,
+             )
+             output = output.logits
+
+         return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
+
+     def __repr__(self):
+         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
+
+     # Calling `greedy_search` directly is deprecated and is removed in transformers v4.41.
+     def greedy_search(self, *args, **kwargs):
+         return self._greedy_search(*args, **kwargs)
+
+     def _greedy_search(
+         self,
+         input_ids: torch.LongTensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         streamer: Optional["BaseStreamer"] = None,
+         **model_kwargs,
+     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
+
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+         return_dict_in_generate = (
+             return_dict_in_generate
+             if return_dict_in_generate is not None
+             else self.generation_config.return_dict_in_generate
+         )
+
+         # init attention / hidden states / scores tuples
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+
+         # keep track of which sequences are already finished
+         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+
+         this_peer_finished = False  # used by synced_gpus only
+
+         while True:
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+             )
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # pre-process distribution
+             next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+
+             # argmax
+             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+             ########################################################################################################
+             # thkim change for right-padding batch
+             # if min_input_len <= update_idx < max_input_len
+             # update/validate input_ids[:, update_idx]
+             # TODO : raw_logits contains dummy next_token's logits
+             update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+             if update_idx < self.rightpad_max_len:
+                 # update existing input_ids rather than concat
+                 valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
+                 input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
+                 model_kwargs["attention_mask"][valid_indices, update_idx] = 1
+
+                 # dummy next_token -> pad_token_id for streamer,
+                 # so it can be skipped with `skip_special_tokens=True`
+                 dummy_indices = ~valid_indices
+                 next_tokens[dummy_indices] = pad_token_id
+             else:
+                 ############################################END#########################################################
+                 # update generated ids, model inputs, and length for next step
+                 input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs,
+                 model_kwargs,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+             )
+
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+
+             # if eos_token was found in one sentence, set sentence to finished
+             if eos_token_id_tensor is not None:
+                 ####################################################################
+                 # thkim : do not finish sequences during the dummy decode steps of right padding
+                 if hasattr(self, "rightpad_max_len"):
+                     update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+                     if update_idx < self.rightpad_max_len:
+                         next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
+                 ######################################################################
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                 )
+
+                 # stop when each sentence is finished
+                 if unfinished_sequences.max() == 0:
+                     this_peer_finished = True
+
+             # stop if we exceed the maximum length
+             # thkim : backward compatibility, bool vs torch.BoolTensor
+             is_stop = stopping_criteria(input_ids, None)
+             if isinstance(is_stop, torch.BoolTensor):
+                 is_stop = torch.all(is_stop)
+             if is_stop:
+                 this_peer_finished = True
+
+             if this_peer_finished:
+                 break
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             return SampleDecoderOnlyOutput(
+                 sequences=input_ids,
+                 logits=raw_logits,
+             )
+         else:
+             return input_ids
+
+     # Calling `sample` directly is deprecated and is removed in transformers v4.41.
+     def sample(self, *args, **kwargs):
+         return self._sample(*args, **kwargs)
+
+     def _sample(
+         self,
+         input_ids: torch.LongTensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         logits_warper: Optional[LogitsProcessorList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_scores: Optional[bool] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         synced_gpus: bool = False,
+         streamer: Optional["BaseStreamer"] = None,
+         **model_kwargs,
+     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+
+         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+         output_logits = output_logits if output_logits is not None else False
+
+         # init attention / hidden states / scores tuples
+         scores = () if (return_dict_in_generate and output_scores) else None
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+
+         # keep track of which sequences are already finished
+         batch_size, cur_len = input_ids.shape
+         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+         this_peer_finished = False
+
+         # model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+
+         while True:
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+             )
+
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # pre-process distribution
+             next_token_scores = logits_processor(input_ids, next_token_logits)
+             next_token_scores = logits_warper(input_ids, next_token_scores)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_token_scores,)
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+
+             # sample
+             probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+             ########################################################################################################
+             # thkim change for right-padding batch
+             # if min_input_len <= update_idx < max_input_len
+             # update/validate input_ids[:, update_idx]
+             # TODO : raw_logits contains dummy next_token's logits
+             update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+             if update_idx < self.rightpad_max_len:
+                 # update existing input_ids rather than concat
+                 valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
+                 input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
+                 model_kwargs["attention_mask"][valid_indices, update_idx] = 1
+
+                 # dummy next_token -> pad_token_id for streamer,
+                 # so it can be skipped with `skip_special_tokens=True`
+                 dummy_indices = ~valid_indices
+                 next_tokens[dummy_indices] = pad_token_id
+             else:
+                 ############################################END#########################################################
+
+                 # update generated ids, model inputs, and length for next step
+                 input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs,
+                 model_kwargs,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+             )
+
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+
+             # if eos_token was found in one sentence, set sentence to finished
+             if eos_token_id_tensor is not None:
+                 ####################################################################
+                 # thkim : do not finish sequences during the dummy decode steps of right padding
+                 if hasattr(self, "rightpad_max_len"):
+                     update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
+                     if update_idx < self.rightpad_max_len:
+                         next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
+                 ######################################################################
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                 )
+
+                 # stop when each sentence is finished
+                 if unfinished_sequences.max() == 0:
+                     this_peer_finished = True
+
+             # stop if we exceed the maximum length
+             # thkim : backward compatibility, bool vs list[bool]
+             is_stop = stopping_criteria(input_ids, None)
+             if isinstance(is_stop, torch.BoolTensor):
+                 is_stop = torch.all(is_stop)
+             if is_stop:
+                 this_peer_finished = True
+
+             if this_peer_finished:
+                 break
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             return SampleDecoderOnlyOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+             )
+         else:
+             return input_ids
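
For context on how the class above is meant to be used: export goes through `_export` (trace with `torch.jit.trace`, lower with `rebel.torchscript_to_ir`, compile the prefill and decode graphs together with `rebel.compile`), and generation runs through the overridden greedy/sampling loops. A minimal usage sketch follows, assuming the package exposes `RBLNGPT2LMHeadModel` from `optimum.rbln` and follows the usual `optimum` `from_pretrained(..., export=True)` convention together with the `rbln_*` keyword arguments consumed by `_get_rbln_config`; the exact public entry point may differ:

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNGPT2LMHeadModel

    # Compile a vanilla GPT-2 checkpoint for the RBLN NPU (assumed export path).
    # rbln_max_seq_len falls back to config.n_positions when omitted.
    model = RBLNGPT2LMHeadModel.from_pretrained(
        "gpt2",
        export=True,
        rbln_max_seq_len=1024,
        rbln_batch_size=1,
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    inputs = tokenizer("Rebellions makes NPUs that", return_tensors="pt")

    # generate() dispatches to the prefill/decode runtimes via the loops defined above.
    output_ids = model.generate(**inputs, max_length=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))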