optimum-rbln 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
- optimum/rbln/diffusers/models/controlnet.py +93 -23
- optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
- optimum/rbln/diffusers/pipelines/__init__.py +7 -2
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
- optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
- optimum/rbln/modeling_base.py +39 -6
- optimum/rbln/modeling_seq2seq.py +19 -4
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/generation/__init__.py +1 -0
- optimum/rbln/transformers/generation/streamers.py +17 -0
- optimum/rbln/transformers/generation/utils.py +399 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
- optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +187 -75
- optimum/rbln/transformers/models/midm/__init__.py +32 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
- optimum_rbln-0.1.4.dist-info/RECORD +63 -0
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.0.dist-info/RECORD +0 -51
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,7 +23,6 @@

 import inspect
 import logging
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -32,19 +31,13 @@ import rebel
 import torch
 from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
-from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import (
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from transformers.generation.streamers import BaseStreamer
-from transformers.generation.utils import SampleDecoderOnlyOutput
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput

 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.save_utils import maybe_save_preprocessors
+from ...generation.utils import RBLNGenerationMixin
 from .gpt2_architecture import GPT2LMHeadModelWrapper


@@ -66,7 +59,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=logits)


-class RBLNGPT2LMHeadModel(RBLNBaseModel):
+class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
@@ -135,6 +128,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNGPT2LMHeadModel":
         """
@@ -144,8 +138,16 @@
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)

-
-
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+            save_dir_path.mkdir(exist_ok=True)

         kwargs.update(
             {
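The block added above resolves `model_save_dir` before any compiled artifacts are written, falling back to a temporary directory when the caller does not supply one. A minimal standalone sketch of that resolution logic (the `resolve_save_dir` helper name is illustrative, not part of the package):

```python
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union


def resolve_save_dir(
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
) -> Tuple[Union[TemporaryDirectory, str, Path], Path]:
    # No directory given: create a TemporaryDirectory and use its path.
    if model_save_dir is None:
        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)
    else:
        save_dir = model_save_dir
        if isinstance(save_dir, TemporaryDirectory):
            save_dir_path = Path(save_dir.name)
        else:
            save_dir_path = Path(save_dir)
        save_dir_path.mkdir(exist_ok=True)
    # Keep `save_dir` referenced: a TemporaryDirectory is deleted once garbage-collected.
    return save_dir, save_dir_path
```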
@@ -264,8 +266,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_pad_token_id"] = rbln_pad_token_id

-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size

         def get_input_info(query_length):
             return [
@@ -320,6 +321,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
             self.prompt_ids = input_ids
             self.rightpad_max_len = cur_len
             prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+            self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len

             if cur_len % self.prefill_chunk_size == 0:
                 pad_len = 0
@@ -329,12 +331,12 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
             attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)

-            query_length = prompt_min_len
+            query_length = prompt_min_len.item()
         else:
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
             attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
             attention_mask[:, :, :, : cache_position + 1] = 1
-            input_ids = input_ids[:,
+            input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
             query_length = 1

         model_inputs = {
@@ -357,25 +359,23 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         query_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-
         if past_key_values is not None:
             past_key_values += query_length

         if cache_position == 0:
-            for
-                sliced_input_ids = input_ids[:,
-                attention_mask[:, :, :, :
-                attention_mask[:, :, :,
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

                 output = self.prefill_decoder(
                     input_ids=sliced_input_ids.contiguous(),
                     attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position,
+                    cache_position=cache_position + step,
                 )
-                query_length -= self.prefill_chunk_size
-                cache_position += self.prefill_chunk_size

-
+            idx = query_length % self.prefill_chunk_size - 1
+            output = output.logits[:, idx].unsqueeze(1)

         else:
             output = self.decoder(
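The rewritten prefill path above walks the prompt in `prefill_chunk_size` chunks, reusing one compiled prefill graph per chunk and then selecting the logits of the last real prompt token. A minimal sketch of the same idea, where `prefill_fn`, `causal_mask`, and the tensor shapes are illustrative assumptions rather than the package's actual runtime API:

```python
import torch


def chunked_prefill(input_ids, prompt_len, prefill_fn, causal_mask, max_seq_len, chunk=128):
    # input_ids is assumed right-padded so its length is a multiple of `chunk`;
    # prompt_len is the real (unpadded) prompt length.
    bsz = input_ids.shape[0]
    attention_mask = torch.zeros(bsz, 1, chunk, max_seq_len, dtype=torch.int64)
    output = None
    for step in range(0, input_ids.shape[1], chunk):
        sliced = input_ids[:, step : step + chunk]
        attention_mask[:, :, :, :step] = 1                          # earlier chunks fully visible
        attention_mask[:, :, :, step : step + chunk] = causal_mask  # causal inside the current chunk
        output = prefill_fn(
            input_ids=sliced.contiguous(),
            attention_mask=attention_mask.contiguous(),
            cache_position=torch.tensor(step, dtype=torch.int32),
        )
    # logits of the last real prompt token inside the final chunk
    idx = prompt_len % chunk - 1
    return output.logits[:, idx].unsqueeze(1)
```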
@@ -389,312 +389,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):

     def __repr__(self):
         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
-
-    # call 'greedy_search` directly is deprecated and removed in v4.41.
-    def greedy_search(self, *args, **kwargs):
-        return self._greedy_search(*args, **kwargs)
-
-    def _greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_logits: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )
-
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-
-        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-
-        this_peer_finished = False # used by synced_gpus only
-
-        while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-            )
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            ########################################################################################################
-            # thkim change for right-padding batch
-            # if min_input_len <= update_idx < max_input_len
-            # update validate input_ids[:,update_idx]
-            # TODO : raw_logits contains dummy next_token's logits
-            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-            if update_idx < self.rightpad_max_len:
-                # update exist input_ids rather than concat
-                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
-                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
-                model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
-                # dummy next_token -> pad_token_id for streamer
-                # in order to skip by 'skip_special_tokens = True"
-                dummy_indices = ~valid_indices
-                next_tokens[dummy_indices] = pad_token_id
-            else:
-                ############################################END#########################################################
-                # update generated ids, model inputs, and length for next step
-                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                ####################################################################
-                # thkim : to do not finish sequence of dummy_decoder of right_padding
-                if hasattr(self, "rightpad_max_len"):
-                    update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-                    if update_idx < self.rightpad_max_len:
-                        next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
-                ######################################################################
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            # thkim : backward compatibility bool vs torch.BoolTensor
-            is_stop = stopping_criteria(input_ids, None)
-            if isinstance(is_stop, torch.BoolTensor):
-                is_stop = torch.all(is_stop)
-            if is_stop:
-                this_peer_finished = True
-
-            if this_peer_finished:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return SampleDecoderOnlyOutput(
-                sequences=input_ids,
-                logits=raw_logits,
-            )
-        else:
-            return input_ids
-
-    # call 'sample` directly is deprecated and removed in v4.41.
-    def sample(self, *args, **kwargs):
-        return self._sample(*args, **kwargs)
-
-    def _sample(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        output_logits: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_logits = output_logits if output_logits is not None else False
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-
-        # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        this_peer_finished = False
-
-        # model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-
-        while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-
-            # sample
-            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            ########################################################################################################
-            # thkim change for right-padding batch
-            # if min_input_len <= update_idx < max_input_len
-            # update validate input_ids[:,update_idx]
-            # TODO : raw_logits contains dummy next_token's logits
-            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-            if update_idx < self.rightpad_max_len:
-                # update exist input_ids rather than concat
-                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
-                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
-                model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
-                # dummy next_token -> pad_token_id for streamer
-                # in order to skip by 'skip_special_tokens = True"
-                dummy_indices = ~valid_indices
-                next_tokens[dummy_indices] = pad_token_id
-            else:
-                ############################################END#########################################################
-
-                # update generated ids, model inputs, and length for next step
-                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                ####################################################################
-                # thkim : to do not finish sequence of dummy_decoder of right_padding
-                if hasattr(self, "rightpad_max_len"):
-                    update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-                    if update_idx < self.rightpad_max_len:
-                        next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
-                ######################################################################
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            # thkim : backward compatibility bool vs list[bool]
-            is_stop = stopping_criteria(input_ids, None)
-            if isinstance(is_stop, torch.BoolTensor):
-                is_stop = torch.all(is_stop)
-            if is_stop:
-                this_peer_finished = True
-
-            if this_peer_finished:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return SampleDecoderOnlyOutput(
-                sequences=input_ids,
-                scores=scores,
-                logits=raw_logits,
-            )
-        else:
-            return input_ids
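The removed `greedy_search`/`_sample` overrides carried the custom right-padded-batch handling; the generation loop itself now appears to live in the new `RBLNGenerationMixin` (`optimum/rbln/transformers/generation/utils.py`, added in this release), whose exact contents are not shown here. A minimal sketch of that right-padding update, with an illustrative helper name:

```python
import torch


def update_right_padded_batch(input_ids, attention_mask, next_tokens, update_idx, pad_token_id):
    # Rows whose prompt already ended at `update_idx` receive the generated token
    # in place instead of having it concatenated at the end of the sequence.
    valid = attention_mask[:, update_idx] == 0
    input_ids[valid, update_idx] = next_tokens[valid]
    attention_mask[valid, update_idx] = 1
    # Rows still consuming their own (longer) prompt emit a pad token, which a
    # streamer can drop via skip_special_tokens=True.
    next_tokens[~valid] = pad_token_id
    return input_ids, attention_mask, next_tokens
```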
optimum/rbln/transformers/models/llama/llama_architecture.py

@@ -36,7 +36,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaModel,
     LlamaRotaryEmbedding,
-    repeat_kv,
 )


@@ -149,26 +148,41 @@ class _LlamaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+        # change to remove repeat
+        key_states = key_states.unsqueeze(2)
+        value_states = value_states.unsqueeze(2)
+        query_states = query_states.view(
+            bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
+        )
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-
-
+        # change to remove repeat
+        # key_states = repeat_kv(key_states, self.num_key_value_groups)
+        # value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-
-
-
-
-
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+        # change to remove repeat
+        # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        #     raise ValueError(
+        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+        #         f" {attn_weights.size()}"
+        #     )

         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+            else:
+                # change to remove repeat
+                attention_mask = attention_mask.unsqueeze(2)
+
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
@@ -176,6 +190,9 @@ class _LlamaAttention(LlamaAttention):
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)

+        # change to remove repeat
+        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
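The `_LlamaAttention` changes above avoid `repeat_kv` by folding the query heads into a `(kv_heads, group)` layout and letting the key/value tensors broadcast over the group dimension. A standalone sketch of the trick (function name and shapes are illustrative assumptions):

```python
import math

import torch


def gqa_without_repeat(q, k, v, attn_mask=None):
    # q: (bsz, num_heads, q_len, head_dim); k, v: (bsz, num_kv_heads, kv_len, head_dim)
    bsz, num_heads, q_len, head_dim = q.shape
    num_kv_heads = k.shape[1]
    group = num_heads // num_kv_heads

    # Fold query heads into (kv_heads, group); K/V get a broadcastable group axis.
    q = q.view(bsz, num_kv_heads, group, q_len, head_dim)
    k = k.unsqueeze(2)  # (bsz, num_kv_heads, 1, kv_len, head_dim)
    v = v.unsqueeze(2)

    attn = torch.matmul(q, k.transpose(3, 4)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn = attn + attn_mask.unsqueeze(2)  # mask: (bsz, 1, q_len, kv_len)
    attn = torch.nn.functional.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)

    out = torch.matmul(attn, v)  # (bsz, num_kv_heads, group, q_len, head_dim)
    return out.view(bsz, num_heads, q_len, head_dim)
```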
@@ -516,17 +533,32 @@ class RebelDynamicCache(DynamicCache):
         if len(self.key_cache) <= layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
         else:
-
-
+            # change to remove repeat
+            # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+            #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            # )
+            # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+            #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            # )
+            updated_key = (
+                self.key_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
+                )
             )
-
-
+            updated_value = (
+                self.value_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
+                )
             )
-
-
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+            self.key_cache[layer_idx] = updated_key.squeeze(2)
+            self.value_cache[layer_idx] = updated_value.squeeze(2)
+            return updated_key, updated_value

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""